Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 118 additions & 102 deletions pandajedi/jediorder/PostProcessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import datetime
"""
Post-processing worker for JEDI tasks.

PostProcessor is a long-running daemon that periodically picks up tasks ready
to be finished and dispatches them to a pool of PostProcessorThread workers.
Each worker calls the VO/label-specific post-processor implementation
(doPostProcess) and, if successful, the final-procedure hook (doFinalProcedure).
"""

import os
import socket
import sys
Expand All @@ -18,149 +26,157 @@
logger = PandaLogger().getLogger(__name__.split(".")[-1])


# worker class to do post-processing
class PostProcessor(JediKnight, FactoryBase):
    """
    Daemon that drives post-processing of JEDI tasks.

    On every loop iteration, for each (vo, prodSourceLabel) combination, it
    calls ``prepareTasksToBeFinished_JEDI`` and ``getTasksToBeFinished_JEDI``
    on the task buffer interface and hands the returned tasks to a pool of
    ``PostProcessorThread`` workers.
    """

    def __init__(self, commuChannel, taskBufferIF, ddmIF, vos, prodSourceLabels):
        """
        Set up the daemon.

        :param commuChannel: forwarded unchanged to ``JediKnight.__init__``.
        :param taskBufferIF: task buffer interface, forwarded to ``JediKnight.__init__``.
        :param ddmIF: DDM interface, forwarded to ``JediKnight.__init__``.
        :param vos: VO specification, normalized through ``self.parseInit``.
        :param prodSourceLabels: source-label specification, normalized through ``self.parseInit``.
        """
        self.vos = self.parseInit(vos)
        self.prodSourceLabels = self.parseInit(prodSourceLabels)
        # identifier built from short hostname and process id; passed to the task buffer as pid
        self.pid = f"{socket.getfqdn().split('.')[0]}-{os.getpid()}-post"
        JediKnight.__init__(self, commuChannel, taskBufferIF, ddmIF, logger)
        FactoryBase.__init__(self, self.vos, self.prodSourceLabels, logger, jedi_config.postprocessor.modConfig)

    def start(self):
        """
        Run the main post-processing loop; never returns.

        Each iteration is padded with a sleep so that one cycle lasts at
        least 60 seconds. Any exception raised inside an iteration is logged
        and the loop carries on with the next cycle.
        """
        # start base classes
        JediKnight.start(self)
        FactoryBase.initializeMods(self, self.taskBufferIF, self.ddmIF)

        # target duration of one cycle in seconds
        loop_cycle = 60

        while True:
            start_time = naive_utcnow()
            # created before the try block so that the except handler below
            # can always rely on it being bound for the current iteration
            tmp_log = MsgWrapper(logger)
            try:
                tmp_log.info("start")

                for vo in self.vos:
                    for prod_source_label in self.prodSourceLabels:
                        # prepare tasks to be finished
                        tmp_log.info(f"preparing tasks to be finished for vo={vo} label={prod_source_label}")
                        target_tasks = self.taskBufferIF.prepareTasksToBeFinished_JEDI(
                            vo, prod_source_label, jedi_config.postprocessor.nTasks, pid=self.pid
                        )
                        if target_tasks is None:
                            # NOTE(review): a failed prepare step is only logged and
                            # target_tasks=None is still passed on below — confirm that
                            # getTasksToBeFinished_JEDI accepts None here
                            tmp_log.error("failed to prepare tasks")

                        # get tasks to be finished
                        tmp_log.info("getting tasks to be finished")
                        task_list = self.taskBufferIF.getTasksToBeFinished_JEDI(
                            vo, prod_source_label, self.pid, jedi_config.postprocessor.nTasks, target_tasks=target_tasks
                        )
                        if task_list is None:
                            tmp_log.error("failed to get tasks to be finished")
                            continue

                        tmp_log.info(f"got {len(task_list)} tasks")
                        # share the tasks among the workers through a locked list
                        locked_list = ListWithLock(task_list)
                        thread_pool = ThreadPool()
                        for _ in range(jedi_config.postprocessor.nWorkers):
                            thr = PostProcessorThread(locked_list, thread_pool, self.taskBufferIF, self.ddmIF, self)
                            thr.start()
                        # wait until all workers are done before moving to the next combination
                        thread_pool.join()

                tmp_log.info("done")
            except Exception as exc:
                tmp_log.error(f"failed in {self.__class__.__name__}.start() with {type(exc).__name__} {exc}")

            # sleep for the remainder of the cycle; total_seconds() is used because
            # timedelta.seconds drops the days component and the sub-second part
            sleep_period = loop_cycle - (naive_utcnow() - start_time).total_seconds()
            if sleep_period > 0:
                time.sleep(sleep_period)
Comment on lines +86 to +90


# thread for real worker
class PostProcessorThread(WorkerThread):
    """
    Worker thread that post-processes JEDI tasks.

    Instances are created by ``PostProcessor.start()``. Each one repeatedly
    takes a batch of tasks from the shared list and passes it to
    ``post_process_tasks()`` until the list returns no more items.
    """

    def __init__(self, taskList, threadPool, taskbufferIF, ddmIF, implFactory):
        """
        Store the collaborators used while processing tasks.

        :param taskList: shared list of task specs; batches are taken with ``get(n)``.
        :param threadPool: thread pool forwarded to ``WorkerThread.__init__``.
        :param taskbufferIF: task buffer interface used to update tasks.
        :param ddmIF: DDM interface forwarded to the post-processor implementation.
        :param implFactory: object providing ``instantiateImpl`` for VO/label-specific implementations.
        """
        # initialize worker with no semaphore
        WorkerThread.__init__(self, None, threadPool, logger)
        self.taskList = taskList
        self.taskBufferIF = taskbufferIF
        self.ddmIF = ddmIF
        self.implFactory = implFactory

    def post_process_tasks(self, task_list):
        """
        Run post-processing and the final procedure for each task in task_list.

        Outcome per task:
          - no implementation, an exception in doPostProcess, SC_FATAL, or
            SC_FAILED while the task status is "toabort"/"tobroken": the task
            is set to "broken" and updated; the final procedure is then still
            run when an implementation exists.
          - SC_FAILED otherwise: the error is recorded on the task, the task
            is updated, and the final procedure is skipped.
          - any other status: the final procedure is run.

        :param task_list: iterable of task specs to process.
        """
        for task_spec in task_list:
            tmp_log = MsgWrapper(self.logger, f"<jediTaskID={task_spec.jediTaskID}>")
            tmp_log.info("start")
            tmp_stat = Interaction.SC_SUCCEEDED

            # instantiate the VO/label-specific post-processor
            impl = self.implFactory.instantiateImpl(task_spec.vo, task_spec.prodSourceLabel, None, self.taskBufferIF, self.ddmIF)
            if impl is None:
                tmp_log.error(f"post-processor is undefined for vo={task_spec.vo} sourceLabel={task_spec.prodSourceLabel}")
                tmp_stat = Interaction.SC_FATAL
            else:
                # run post-processing; any exception is treated as a permanent failure
                tmp_log.info(f"post-process with {impl.__class__.__name__}")
                try:
                    tmp_stat = impl.doPostProcess(task_spec, tmp_log)
                except Exception as e:
                    tmp_log.error(f"post-process failed with {str(e)}")
                    tmp_stat = Interaction.SC_FATAL

            if tmp_stat == Interaction.SC_FATAL or (tmp_stat == Interaction.SC_FAILED and task_spec.status in ("toabort", "tobroken")):
                # permanent failure: mark the task as broken and release the lock
                err_str = "post-process permanently failed"
                tmp_log.error(err_str)
                task_spec.status = "broken"
                task_spec.setErrDiag(err_str)
                task_spec.lockedBy = None
                self.taskBufferIF.updateTask_JEDI(task_spec, {"jediTaskID": task_spec.jediTaskID})
            elif tmp_stat == Interaction.SC_FAILED:
                # transient failure: record the error and skip the final procedure
                err_str = "post-processing temporarily failed"
                task_spec.setErrDiag(err_str, True)
                self.taskBufferIF.updateTask_JEDI(task_spec, {"jediTaskID": task_spec.jediTaskID})
                tmp_log.info(f"set task_status={task_spec.status} since {task_spec.errorDialog}")
                tmp_log.info("done")
                continue

            # final procedure; skipped when there is no implementation, since calling
            # it on None would only produce a misleading "final procedure failed" error
            if impl is not None:
                try:
                    impl.doFinalProcedure(task_spec, tmp_log)
                except Exception as e:
                    tmp_log.error(f"final procedure failed with {str(e)}")

            tmp_log.info("done")

    def runImpl(self):
        """
        Take batches of up to 10 tasks from the shared list and post-process them.

        Returns when the shared list yields an empty batch. Exceptions are
        logged and the loop continues with the next batch.
        """
        while True:
            try:
                task_list = self.taskList.get(10)
                if not task_list:
                    self.logger.debug(f"{self.__class__.__name__} terminating since no more items")
                    return
                self.post_process_tasks(task_list)
            except Exception as exc:
                logger.error(f"{self.__class__.__name__} failed in runImpl() with {type(exc).__name__}:{exc}")


def launcher(commuChannel, taskBufferIF, ddmIF, vos=None, prodSourceLabels=None):
    """
    Create a PostProcessor from the given arguments and start it.

    All arguments are forwarded unchanged to the PostProcessor constructor.
    """
    PostProcessor(commuChannel, taskBufferIF, ddmIF, vos, prodSourceLabels).start()
Loading
Loading