Skip to content

Commit cb2d0e3

Browse files
refactor to await results
1 parent 72e95c3 commit cb2d0e3

3 files changed

Lines changed: 59 additions & 28 deletions

File tree

dojo/finding/helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def post_process_finding_save_signature(finding, dedupe_option=True, rules_optio
365365
issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed
366366
"""
367367
Returns a task signature for post-processing a finding. This is useful for creating task signatures
368-
that can be used in chords or groups.
368+
that can be used in chords or groups or to await results. We need this extra method because of our dojo_async decorator.
369+
If we use more of these celery features, we should probably move away from that decorator.
369370
"""
370371
return post_process_finding_save_internal(finding, dedupe_option, rules_option, product_grading_option,
371372
issue_updater_option, push_to_jira, user, *args, **kwargs)

dojo/importers/default_importer.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import dojo.jira_link.helper as jira_helper
99
from dojo.decorators import we_want_async
10+
from dojo.finding import helper as finding_helper
1011
from dojo.importers.base_importer import BaseImporter, Parser
1112
from dojo.importers.options import ImporterOptions
1213
from dojo.models import (
@@ -16,6 +17,8 @@
1617
Test_Import,
1718
)
1819
from dojo.notifications.helper import create_notification
20+
from dojo.tasks import wait_for_tasks_and_calculate_grade
21+
from dojo.utils import calculate_grade
1922
from dojo.validators import clean_tags
2023

2124
logger = logging.getLogger(__name__)
@@ -155,11 +158,7 @@ def process_findings(
155158
parsed_findings: list[Finding],
156159
**kwargs: dict,
157160
) -> list[Finding]:
158-
from celery import chord
159-
160-
from dojo.finding import helper as finding_helper
161-
from dojo.utils import calculate_grade, calculate_grade_signature
162-
post_processing_task_signatures = []
161+
async_task_ids = []
163162

164163
"""
165164
Saves findings in memory that were parsed from the scan report into the database.
@@ -189,7 +188,7 @@ def process_findings(
189188
unsaved_finding.reporter = self.user
190189
unsaved_finding.last_reviewed_by = self.user
191190
unsaved_finding.last_reviewed = self.now
192-
logger.debug("process_parsed_findings: unique_id_from_tool: %s, hash_code: %s, active from report: %s, verified from report: %s", unsaved_finding.unique_id_from_tool, unsaved_finding.hash_code, unsaved_finding.active, unsaved_finding.verified)
191+
logger.debug("process_parsed_finding: unique_id_from_tool: %s, hash_code: %s, active from report: %s, verified from report: %s", unsaved_finding.unique_id_from_tool, unsaved_finding.hash_code, unsaved_finding.active, unsaved_finding.verified)
193192
# indicates an override. Otherwise, do not change the value of unsaved_finding.active
194193
if self.active is not None:
195194
unsaved_finding.active = self.active
@@ -238,20 +237,25 @@ def process_findings(
238237
new_findings.append(finding)
239238
# all data is already saved on the finding, we only need to trigger post processing
240239

241-
# Collect finding for parallel processing - we'll process them all at once after the loop
240+
# We create a signature for the post processing task so we can decide to apply it async or sync
242241
push_to_jira = self.push_to_jira and (not self.findings_groups_enabled or not self.group_by)
243-
# Always create signatures - we'll execute them sync or async later
244-
post_processing_task_signatures.append(
245-
finding_helper.post_process_finding_save_signature(
246-
finding,
247-
dedupe_option=True,
248-
rules_option=True,
249-
product_grading_option=False,
250-
issue_updater_option=True,
251-
push_to_jira=push_to_jira,
252-
),
242+
post_processing_task_signature = finding_helper.post_process_finding_save_signature(
243+
finding,
244+
dedupe_option=True,
245+
rules_option=True,
246+
product_grading_option=False,
247+
issue_updater_option=True,
248+
push_to_jira=push_to_jira,
253249
)
254250

251+
# We need to call apply_async to get the result of the task so we can collect the task ID
252+
if we_want_async(async_user=self.user):
253+
result = post_processing_task_signature.apply_async()
254+
async_task_ids.append(result.id)
255+
else:
256+
# Execute task immediately for synchronous processing
257+
post_processing_task_signature()
258+
255259
for (group_name, findings) in group_names_to_findings_dict.items():
256260
finding_helper.add_findings_to_auto_group(
257261
group_name,
@@ -268,17 +272,12 @@ def process_findings(
268272

269273
# Calculate product grade after all findings are processed
270274
product = self.test.engagement.product
271-
if post_processing_task_signatures:
272-
# If we have async tasks, use chord to wait for them before calculating grade
273-
if we_want_async(async_user=self.user):
274-
# Run the chord asynchronously and after completing post processing tasks, calculate grade ONCE
275-
chord(post_processing_task_signatures)(calculate_grade_signature(product))
276-
else:
277-
# Execute each task synchronously
278-
for task_sig in post_processing_task_signatures:
279-
task_sig()
280275

281-
# Calculate grade, which can be preliminarily calculated before the async tasks have finished
276+
if we_want_async(async_user=self.user) and async_task_ids:
277+
# Tasks were executed immediately during processing; now coordinate the final grade calculation
278+
wait_for_tasks_and_calculate_grade.delay(async_task_ids, product.id)
279+
280+
# Synchronous tasks were already executed during processing; just calculate the grade
282281
calculate_grade(product)
283282

284283
sync = kwargs.get("sync", True)

dojo/tasks.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import date, timedelta
33

44
from auditlog.models import LogEntry
5+
from celery.result import AsyncResult
56
from celery.utils.log import get_task_logger
67
from dateutil.relativedelta import relativedelta
78
from django.conf import settings
@@ -192,6 +193,36 @@ def fix_loop_duplicates_task(*args, **kwargs):
192193
return fix_loop_duplicates()
193194

194195

196+
@app.task
197+
def wait_for_tasks_and_calculate_grade(task_ids, product_id, *args, **kwargs):
198+
"""
199+
Wait for all specified tasks to complete, then calculate product grade.
200+
This provides coordination for immediate task execution without using chord.
201+
"""
202+
logger.info(f"Waiting for {len(task_ids)} tasks to complete before calculating grade for product {product_id}")
203+
204+
# Wait for all tasks to complete
205+
results = [AsyncResult(task_id) for task_id in task_ids]
206+
207+
# This will block until all tasks are done
208+
for result in results:
209+
try:
210+
result.get(timeout=300) # 5 minute timeout per task
211+
except Exception as e:
212+
logger.warning(f"Task {result.id} failed: {e}")
213+
# Continue waiting for other tasks even if one fails
214+
215+
# All tasks completed, now calculate grade
216+
try:
217+
product = Product.objects.get(id=product_id)
218+
logger.info(f"All post-processing tasks completed, calculating grade for product {product.name}")
219+
calculate_grade(product)
220+
except Product.DoesNotExist:
221+
logger.error(f"Product {product_id} not found for grade calculation")
222+
except Exception as e:
223+
logger.error(f"Error calculating grade for product {product_id}: {e}")
224+
225+
195226
@app.task
196227
def evaluate_pro_proposition(*args, **kwargs):
197228
# Ensure we should be doing this

0 commit comments

Comments
 (0)