Skip to content

Commit 87e5d45

Browse files
switch back to chords
1 parent 26295cb commit 87e5d45

3 files changed

Lines changed: 68 additions & 50 deletions

File tree

dojo/importers/default_importer.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
22

3+
from celery import chord
34
from django.core.files.uploadedfile import TemporaryUploadedFile
45
from django.core.serializers import serialize
56
from django.db.models.query_utils import Q
67
from django.urls import reverse
78

89
import dojo.jira_link.helper as jira_helper
10+
from dojo import utils
911
from dojo.decorators import we_want_async
1012
from dojo.finding import helper as finding_helper
1113
from dojo.importers.base_importer import BaseImporter, Parser
@@ -17,7 +19,6 @@
1719
Test_Import,
1820
)
1921
from dojo.notifications.helper import create_notification
20-
from dojo.tasks import wait_for_tasks_and_calculate_grade
2122
from dojo.utils import calculate_grade
2223
from dojo.validators import clean_tags
2324

@@ -158,7 +159,11 @@ def process_findings(
158159
parsed_findings: list[Finding],
159160
**kwargs: dict,
160161
) -> list[Finding]:
161-
async_task_ids = []
162+
# Progressive batching for chord execution
163+
post_processing_task_signatures = []
164+
current_batch_number = 1
165+
max_batch_size = 1024
166+
pending_grade_calculations = []
162167

163168
"""
164169
Saves findings in memory that were parsed from the scan report into the database.
@@ -248,10 +253,25 @@ def process_findings(
248253
push_to_jira=push_to_jira,
249254
)
250255

251-
# We need to call apply_async to get the result of the task so we can collect the task ID
252256
if we_want_async(async_user=self.user):
253-
result = post_processing_task_signature.apply_async()
254-
async_task_ids.append(result.id)
257+
# Collect signatures for progressive batch execution
258+
post_processing_task_signatures.append(post_processing_task_signature)
259+
260+
# Calculate current batch size: 2^batch_number, capped at max_batch_size
261+
current_batch_size = min(2 ** current_batch_number, max_batch_size)
262+
263+
# Launch chord when batch is full
264+
if len(post_processing_task_signatures) >= current_batch_size:
265+
product = self.test.engagement.product
266+
calculate_grade_signature = utils.calculate_grade_signature(product)
267+
chord_result = chord(post_processing_task_signatures)(calculate_grade_signature)
268+
pending_grade_calculations.append(chord_result)
269+
270+
logger.debug(f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {current_batch_size})")
271+
272+
# Reset for next batch
273+
post_processing_task_signatures = []
274+
current_batch_number += 1
255275
else:
256276
# Execute task immediately for synchronous processing
257277
post_processing_task_signature()
@@ -270,14 +290,18 @@ def process_findings(
270290
else:
271291
jira_helper.push_to_jira(findings[0])
272292

273-
# Calculate product grade after all findings are processed
293+
# Handle any remaining signatures in the final batch
274294
product = self.test.engagement.product
275295

276-
if we_want_async(async_user=self.user) and async_task_ids:
277-
# Tasks were executed immediately during processing, now coordinate final grade calculation
278-
wait_for_tasks_and_calculate_grade.delay(async_task_ids, product.id)
296+
if we_want_async(async_user=self.user):
297+
if post_processing_task_signatures:
298+
# Launch final chord with remaining signatures
299+
calculate_grade_signature = utils.calculate_grade_signature(product)
300+
chord_result = chord(post_processing_task_signatures)(calculate_grade_signature)
301+
pending_grade_calculations.append(chord_result)
302+
logger.debug(f"Launched final chord with {len(post_processing_task_signatures)} remaining tasks")
279303

280-
# Synchronous tasks were already executed during processing, just calculate grade
304+
# Always perform an initial grading, even though it might get overwritten later.
281305
calculate_grade(product)
282306

283307
sync = kwargs.get("sync", True)

dojo/importers/default_reimporter.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
22

3+
from celery import chord
34
from django.core.files.uploadedfile import TemporaryUploadedFile
45
from django.core.serializers import serialize
56
from django.db.models.query_utils import Q
67

78
import dojo.finding.helper as finding_helper
89
import dojo.jira_link.helper as jira_helper
10+
from dojo import utils
911
from dojo.decorators import we_want_async
1012
from dojo.importers.base_importer import BaseImporter, Parser
1113
from dojo.importers.options import ImporterOptions
@@ -16,7 +18,6 @@
1618
Test,
1719
Test_Import,
1820
)
19-
from dojo.tasks import wait_for_tasks_and_calculate_grade
2021
from dojo.utils import calculate_grade
2122
from dojo.validators import clean_tags
2223

@@ -179,7 +180,11 @@ def process_findings(
179180
self.reactivated_items = []
180181
self.unchanged_items = []
181182
self.group_names_to_findings_dict = {}
182-
async_task_ids = []
183+
# Progressive batching for chord execution
184+
post_processing_task_signatures = []
185+
current_batch_number = 1
186+
max_batch_size = 1024
187+
pending_grade_calculations = []
183188

184189
logger.debug(f"starting reimport of {len(parsed_findings) if parsed_findings else 0} items.")
185190
logger.debug("STEP 1: looping over findings from the reimported report and trying to match them to existing findings")
@@ -254,9 +259,24 @@ def process_findings(
254259
push_to_jira=push_to_jira,
255260
)
256261
if we_want_async(async_user=self.user):
257-
# Execute task immediately and collect task ID
258-
result = post_processing_task_signature.apply_async()
259-
async_task_ids.append(result.id)
262+
# Collect signatures for progressive batch execution
263+
post_processing_task_signatures.append(post_processing_task_signature)
264+
265+
# Calculate current batch size: 2^batch_number, capped at max_batch_size
266+
current_batch_size = min(2 ** current_batch_number, max_batch_size)
267+
268+
# Launch chord when batch is full
269+
if len(post_processing_task_signatures) >= current_batch_size:
270+
product = self.test.engagement.product
271+
calculate_grade_signature = utils.calculate_grade_signature(product)
272+
chord_result = chord(post_processing_task_signatures)(calculate_grade_signature)
273+
pending_grade_calculations.append(chord_result)
274+
275+
logger.debug(f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {current_batch_size})")
276+
277+
# Reset for next batch
278+
post_processing_task_signatures = []
279+
current_batch_number += 1
260280
else:
261281
# Execute task immediately for synchronous processing
262282
post_processing_task_signature()
@@ -272,12 +292,17 @@ def process_findings(
272292
# Process groups
273293
self.process_groups_for_all_findings(**kwargs)
274294

275-
# Calculate product grade once after all findings are processed
295+
# Handle any remaining signatures in the final batch
276296
product = self.test.engagement.product
277297

278-
if we_want_async(async_user=self.user) and async_task_ids:
279-
# Tasks were executed immediately during processing, now coordinate final grade calculation
280-
wait_for_tasks_and_calculate_grade.delay(async_task_ids, product.id)
298+
if we_want_async(async_user=self.user):
299+
if post_processing_task_signatures:
300+
# Launch final chord with remaining signatures
301+
calculate_grade_signature = utils.calculate_grade_signature(product)
302+
chord_result = chord(post_processing_task_signatures)(calculate_grade_signature)
303+
pending_grade_calculations.append(chord_result)
304+
logger.debug(f"Launched final chord with {len(post_processing_task_signatures)} remaining tasks")
305+
281306
# Synchronous tasks were already executed during processing, just calculate grade
282307
calculate_grade(product)
283308

dojo/tasks.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import date, timedelta
33

44
from auditlog.models import LogEntry
5-
from celery.result import AsyncResult
65
from celery.utils.log import get_task_logger
76
from dateutil.relativedelta import relativedelta
87
from django.conf import settings
@@ -193,36 +192,6 @@ def fix_loop_duplicates_task(*args, **kwargs):
193192
return fix_loop_duplicates()
194193

195194

196-
@app.task
197-
def wait_for_tasks_and_calculate_grade(task_ids, product_id, *args, **kwargs):
198-
"""
199-
Wait for all specified tasks to complete, then calculate product grade.
200-
This provides coordination for immediate task execution without using chord.
201-
"""
202-
logger.info(f"Waiting for {len(task_ids)} tasks to complete before calculating grade for product {product_id}")
203-
204-
# Wait for all tasks to complete
205-
results = [AsyncResult(task_id) for task_id in task_ids]
206-
207-
# This will block until all tasks are done
208-
for result in results:
209-
try:
210-
result.get(timeout=300) # 5 minute timeout per task
211-
except Exception as e:
212-
logger.warning(f"Task {result.id} failed: {e}")
213-
# Continue waiting for other tasks even if one fails
214-
215-
# All tasks completed, now calculate grade
216-
try:
217-
product = Product.objects.get(id=product_id)
218-
logger.info(f"All post-processing tasks completed, calculating grade for product {product.name}")
219-
calculate_grade(product)
220-
except Product.DoesNotExist:
221-
logger.error(f"Product {product_id} not found for grade calculation")
222-
except Exception as e:
223-
logger.error(f"Error calculating grade for product {product_id}: {e}")
224-
225-
226195
@app.task
227196
def evaluate_pro_proposition(*args, **kwargs):
228197
# Ensure we should be doing this

0 commit comments

Comments
 (0)