Skip to content

Commit 60fa754

Browse files
author
Kasper
committed
Added possibility to pass extra args to inference function
1 parent 113ee12 commit 60fa754

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

patched_yolo_infer/elements/CropElement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def __init__(
3232
self.detected_masks_real = None # List of np arrays containing masks in case of yolo-seg with the size of source_image_resized or source_image
3333
self.detected_polygons_real = None # List of polygons points in case of using memory optimaze in values from source_image_resized or source_image
3434

35-
def calculate_inference(self, model, imgsz=640, conf=0.35, iou=0.7, segment=False, classes_list=None, memory_optimize=False):
35+
def calculate_inference(self, model, imgsz=640, conf=0.35, iou=0.7, segment=False, classes_list=None, memory_optimize=False, extra_args=None):
3636

3737
# Perform inference
38-
predictions = model.predict(self.crop, imgsz=imgsz, conf=conf, iou=iou, classes=classes_list, verbose=False)
38+
predictions = model.predict(self.crop, imgsz=imgsz, conf=conf, iou=iou, classes=classes_list, verbose=False, **extra_args)
3939

4040
pred = predictions[0]
4141

patched_yolo_infer/nodes/MakeCropsDetectThem.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def __init__(
6868
show_crops=False,
6969
resize_initial_size=False,
7070
model=None,
71-
memory_optimize=True
71+
memory_optimize=True,
72+
inference_extra_args=None,
7273
) -> None:
7374
if model is None:
7475
self.model = YOLO(model_path) # Load the model from the specified path
@@ -89,6 +90,7 @@ def __init__(
8990
self.resize_initial_size = resize_initial_size # slow operation !
9091
self.memory_optimize = memory_optimize # memory opimization option for segmentation
9192
self.class_names_dict = self.model.names
93+
self.inference_extra_args = inference_extra_args
9294

9395
self.crops = self.get_crops_xy(
9496
self.image,
@@ -199,7 +201,8 @@ def _detect_objects(self):
199201
iou=self.iou,
200202
segment=self.segment,
201203
classes_list=self.classes_list,
202-
memory_optimize=self.memory_optimize
204+
memory_optimize=self.memory_optimize,
205+
extra_args=self.inference_extra_args
203206
)
204207
crop.calculate_real_values()
205208
if self.resize_initial_size:

0 commit comments

Comments
 (0)