Skip to content

Commit bc5a18c

Browse files
committed
Refactor authorization functions to use type hints for better clarity and maintainability
1 parent ad47a1b commit bc5a18c

1 file changed

Lines changed: 23 additions & 21 deletions

File tree

dojo/authorization/authorization.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.core.exceptions import PermissionDenied
2+
from django.db.models import Model, QuerySet
23

34
from dojo.authorization.roles_permissions import (
45
Permissions,
@@ -11,6 +12,7 @@
1112
Cred_Mapping,
1213
Dojo_Group,
1314
Dojo_Group_Member,
15+
Dojo_User,
1416
Endpoint,
1517
Engagement,
1618
Finding,
@@ -30,7 +32,7 @@
3032
from dojo.request_cache import cache_for_request
3133

3234

33-
def user_has_configuration_permission(user, permission):
35+
def user_has_configuration_permission(user: Dojo_User, permission: str):
3436
if not user:
3537
return False
3638

@@ -40,7 +42,7 @@ def user_has_configuration_permission(user, permission):
4042
return user.has_perm(permission)
4143

4244

43-
def user_is_superuser_or_global_owner(user):
45+
def user_is_superuser_or_global_owner(user: Dojo_User) -> bool:
4446
"""
4547
Returns True if the user is a superuser or has a global role (directly or
4648
via group membership) whose Role.is_owner is True.
@@ -69,7 +71,7 @@ def user_is_superuser_or_global_owner(user):
6971
return False
7072

7173

72-
def user_has_permission(user, obj, permission):
74+
def user_has_permission(user: Dojo_User, obj: Model, permission: int) -> bool:
7375
if user.is_anonymous:
7476
return False
7577

@@ -229,7 +231,7 @@ def user_has_permission(user, obj, permission):
229231
raise NoAuthorizationImplementedError(msg)
230232

231233

232-
def user_has_global_permission(user, permission):
234+
def user_has_global_permission(user: Dojo_User, permission: int) -> bool:
233235
if not user:
234236
return False
235237

@@ -263,22 +265,22 @@ def user_has_global_permission(user, permission):
263265
return False
264266

265267

266-
def user_has_configuration_permission_or_403(user, permission):
268+
def user_has_configuration_permission_or_403(user: Dojo_User, permission: str) -> None:
267269
if not user_has_configuration_permission(user, permission):
268270
raise PermissionDenied
269271

270272

271-
def user_has_permission_or_403(user, obj, permission):
273+
def user_has_permission_or_403(user: Dojo_User, obj: Model, permission: int) -> None:
272274
if not user_has_permission(user, obj, permission):
273275
raise PermissionDenied
274276

275277

276-
def user_has_global_permission_or_403(user, permission):
278+
def user_has_global_permission_or_403(user: Dojo_User, permission: int) -> None:
277279
if not user_has_global_permission(user, permission):
278280
raise PermissionDenied
279281

280282

281-
def get_roles_for_permission(permission):
283+
def get_roles_for_permission(permission: int) -> set[int]:
282284
if not Permissions.has_value(permission):
283285
msg = f"Permission {permission} does not exist"
284286
raise PermissionDoesNotExistError(msg)
@@ -291,7 +293,7 @@ def get_roles_for_permission(permission):
291293
return roles_for_permissions
292294

293295

294-
def role_has_permission(role, permission):
296+
def role_has_permission(role: int, permission: int) -> bool:
295297
if role is None:
296298
return False
297299
if not Roles.has_value(role):
@@ -304,7 +306,7 @@ def role_has_permission(role, permission):
304306
return permission in permissions
305307

306308

307-
def role_has_global_permission(role, permission):
309+
def role_has_global_permission(role: int, permission: int) -> bool:
308310
if role is None:
309311
return False
310312
if not Roles.has_value(role):
@@ -332,12 +334,12 @@ def __init__(self, message):
332334
self.message = message
333335

334336

335-
def get_product_member(user, product):
337+
def get_product_member(user: Dojo_User, product: Product) -> Product_Member | None:
336338
return get_product_member_dict(user).get(product.id)
337339

338340

339341
@cache_for_request
340-
def get_product_member_dict(user):
342+
def get_product_member_dict(user: Dojo_User) -> dict[int, Product_Member]:
341343
pm_dict = {}
342344
for product_member in (
343345
Product_Member.objects.select_related("product")
@@ -348,12 +350,12 @@ def get_product_member_dict(user):
348350
return pm_dict
349351

350352

351-
def get_product_type_member(user, product_type):
353+
def get_product_type_member(user: Dojo_User, product_type: Product_Type) -> Product_Type_Member | None:
352354
return get_product_type_member_dict(user).get(product_type.id)
353355

354356

355357
@cache_for_request
356-
def get_product_type_member_dict(user):
358+
def get_product_type_member_dict(user: Dojo_User) -> dict[int, Product_Type_Member]:
357359
ptm_dict = {}
358360
for product_type_member in (
359361
Product_Type_Member.objects.select_related("product_type")
@@ -364,12 +366,12 @@ def get_product_type_member_dict(user):
364366
return ptm_dict
365367

366368

367-
def get_product_groups(user, product):
369+
def get_product_groups(user: Dojo_User, product: Product) -> list[Product_Group]:
368370
return get_product_groups_dict(user).get(product.id, [])
369371

370372

371373
@cache_for_request
372-
def get_product_groups_dict(user):
374+
def get_product_groups_dict(user: Dojo_User) -> dict[int, list[Product_Group]]:
373375
pg_dict = {}
374376
for product_group in (
375377
Product_Group.objects.select_related("product")
@@ -382,12 +384,12 @@ def get_product_groups_dict(user):
382384
return pg_dict
383385

384386

385-
def get_product_type_groups(user, product_type):
387+
def get_product_type_groups(user: Dojo_User, product_type: Product_Type) -> list[Product_Type_Group]:
386388
return get_product_type_groups_dict(user).get(product_type.id, [])
387389

388390

389391
@cache_for_request
390-
def get_product_type_groups_dict(user):
392+
def get_product_type_groups_dict(user: Dojo_User) -> dict[int, list[Product_Type_Group]]:
391393
pgt_dict = {}
392394
for product_type_group in (
393395
Product_Type_Group.objects.select_related("product_type")
@@ -404,16 +406,16 @@ def get_product_type_groups_dict(user):
404406

405407

406408
@cache_for_request
407-
def get_groups(user):
409+
def get_groups(user: Dojo_User) -> QuerySet[Dojo_Group]:
408410
return Dojo_Group.objects.select_related("global_role").filter(users=user)
409411

410412

411-
def get_group_member(user, group):
413+
def get_group_member(user: Dojo_User, group: Dojo_Group) -> dict[int, Dojo_Group_Member]:
412414
return get_group_members_dict(user).get(group.id)
413415

414416

415417
@cache_for_request
416-
def get_group_members_dict(user):
418+
def get_group_members_dict(user: Dojo_User) -> dict[int, Dojo_Group_Member]:
417419
gu_dict = {}
418420
for group_member in (
419421
Dojo_Group_Member.objects.select_related("group")

0 commit comments

Comments
 (0)