From 962ac55026ae8194c9038c12a7d59e27e60f6f52 Mon Sep 17 00:00:00 2001 From: David Whittaker Date: Fri, 30 May 2025 11:02:03 -0700 Subject: [PATCH] fix(service): get_by_name_or_default no longer works if name wasn't found --- src/dispatch/case/priority/service.py | 21 +++++---- src/dispatch/case/severity/service.py | 47 ++++++++++--------- src/dispatch/case/type/service.py | 12 ++--- src/dispatch/incident/priority/service.py | 43 +++++++++-------- src/dispatch/incident/severity/service.py | 36 +++++++------- src/dispatch/incident/type/service.py | 15 +++--- src/dispatch/organization/service.py | 26 +++++----- src/dispatch/project/service.py | 24 +++++----- tests/case_type/test_case_type_service.py | 29 ++++++++++++ .../test_incident_priority_service.py | 34 ++++++++++++++ .../test_incident_type_service.py | 29 ++++++++++++ .../organization/test_organization_service.py | 24 ++++++++++ tests/project/test_project_service.py | 26 +++++++++- tests/signal/test_signal_service.py | 6 +-- 14 files changed, 262 insertions(+), 110 deletions(-) diff --git a/src/dispatch/case/priority/service.py b/src/dispatch/case/priority/service.py index 8333009b2244..76dc99474110 100644 --- a/src/dispatch/case/priority/service.py +++ b/src/dispatch/case/priority/service.py @@ -41,7 +41,7 @@ def get_default_or_raise(*, db_session, project_id: int) -> CasePriority: "input": None, "ctx": {"error": ValueError("No default case priority defined.")}, } - ] + ], ) return case_priority @@ -73,9 +73,11 @@ def get_by_name_or_raise( "loc": ("case_priority",), "input": case_priority_in.name, "msg": "Value error, Case priority not found.", - "ctx": {"error": ValueError(f"Case priority not found: {case_priority_in.name}")} + "ctx": { + "error": ValueError(f"Case priority not found: {case_priority_in.name}") + }, } - ] + ], ) return case_priority @@ -85,13 +87,12 @@ def get_by_name_or_default( *, db_session, project_id: int, case_priority_in=CasePriorityRead ) -> CasePriority: """Returns a case priority based on a name or the default if not specified.""" - if case_priority_in: - if case_priority_in.name: - return get_by_name_or_raise( - db_session=db_session, - project_id=project_id, - case_priority_in=case_priority_in, - ) + if case_priority_in and case_priority_in.name: + case_priority = get_by_name( + db_session=db_session, project_id=project_id, name=case_priority_in.name + ) + if case_priority: + return case_priority return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/case/severity/service.py b/src/dispatch/case/severity/service.py index caf932b6b3e4..b5420f2937d2 100644 --- a/src/dispatch/case/severity/service.py +++ b/src/dispatch/case/severity/service.py @@ -32,13 +32,15 @@ def get_default_or_raise(*, db_session, project_id: int) -> CaseSeverity: case_severity = get_default(db_session=db_session, project_id=project_id) if not case_severity: - raise ValidationError([ - { - "loc": ("case_severity",), - "msg": "No default case severity defined.", - "type": "value_error", - } - ]) + raise ValidationError( + [ + { + "loc": ("case_severity",), + "msg": "No default case severity defined.", + "type": "value_error", + } + ] + ) return case_severity @@ -61,14 +63,16 @@ def get_by_name_or_raise( ) if not case_severity: - raise ValidationError([ - { - "loc": ("case_severity",), - "msg": "Case severity not found.", - "type": "value_error", - "case_severity": case_severity_in.name, - } - ]) + raise ValidationError( + [ + { + "loc": ("case_severity",), + "msg": "Case severity not found.", + "type": "value_error", + "case_severity": case_severity_in.name, + } + ] + ) return case_severity @@ -77,13 +81,12 @@ def get_by_name_or_default( *, db_session, project_id: int, case_severity_in=CaseSeverityRead ) -> CaseSeverity: """Returns a case severity based on a name or the default if not specified.""" - if case_severity_in: - if case_severity_in.name: - return get_by_name_or_raise( - db_session=db_session, - project_id=project_id, - case_severity_in=case_severity_in, - ) + if case_severity_in and case_severity_in.name: + case_severity = get_by_name( + db_session=db_session, project_id=project_id, name=case_severity_in.name + ) + if case_severity: + return case_severity return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/case/type/service.py b/src/dispatch/case/type/service.py index 5f7decedd17b..4468a5027db4 100644 --- a/src/dispatch/case/type/service.py +++ b/src/dispatch/case/type/service.py @@ -1,4 +1,3 @@ - from sqlalchemy.sql.expression import true from dispatch.case import service as case_service @@ -58,11 +57,12 @@ def get_by_name_or_raise(*, db_session, project_id: int, case_type_in=CaseTypeRe def get_by_name_or_default(*, db_session, project_id: int, case_type_in=CaseTypeRead) -> CaseType: """Returns a case type based on a name or the default if not specified.""" - if case_type_in: - if case_type_in.name: - return get_by_name_or_raise( - db_session=db_session, project_id=project_id, case_type_in=case_type_in - ) + if case_type_in and case_type_in.name: + case_type = get_by_name( + db_session=db_session, project_id=project_id, name=case_type_in.name + ) + if case_type: + return case_type return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/incident/priority/service.py b/src/dispatch/incident/priority/service.py index b012fe582e97..66c435b72497 100644 --- a/src/dispatch/incident/priority/service.py +++ b/src/dispatch/incident/priority/service.py @@ -36,12 +36,14 @@ def get_default_or_raise(*, db_session, project_id: int) -> IncidentPriority: incident_priority = get_default(db_session=db_session, project_id=project_id) if not incident_priority: - raise ValidationError([ - { - "msg": "No default incident priority defined.", - "loc": "incident_priority", - } - ]) + raise ValidationError( + [ + { + "msg": "No default incident priority defined.", + "loc": "incident_priority", + } + ] + ) return incident_priority @@ -64,13 +66,15 @@ def get_by_name_or_raise( ) if not incident_priority: - raise ValidationError([ - { - "msg": "Incident priority not found.", - "loc": "incident_priority", - "incident_priority": incident_priority_in.name, - } - ]) + raise ValidationError( + [ + { + "msg": "Incident priority not found.", + "loc": "incident_priority", + "incident_priority": incident_priority_in.name, + } + ] + ) return incident_priority @@ -79,13 +83,12 @@ def get_by_name_or_default( *, db_session, project_id: int, incident_priority_in=IncidentPriorityRead ) -> IncidentPriority: """Returns a incident priority based on a name or the default if not specified.""" - if incident_priority_in: - if incident_priority_in.name: - return get_by_name_or_raise( - db_session=db_session, - project_id=project_id, - incident_priority_in=incident_priority_in, - ) + if incident_priority_in and incident_priority_in.name: + incident_priority = get_by_name( + db_session=db_session, project_id=project_id, name=incident_priority_in.name + ) + if incident_priority: + return incident_priority return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/incident/severity/service.py b/src/dispatch/incident/severity/service.py index f8fc37d6953d..7c09c2200e7d 100644 --- a/src/dispatch/incident/severity/service.py +++ b/src/dispatch/incident/severity/service.py @@ -44,9 +44,9 @@ def get_default_or_raise(*, db_session, project_id: int) -> IncidentSeverity: "loc": ("incident_severity",), "input": None, "msg": "No default incident severity defined.", - "ctx": {"error": ValueError("No default incident severity defined.")} + "ctx": {"error": ValueError("No default incident severity defined.")}, } - ] + ], ) return incident_severity @@ -71,14 +71,16 @@ def get_by_name_or_raise( ) if not incident_severity: - raise ValidationError([ - { - "msg": "Incident severity not found.", - "loc": ("incident_severity",), - "type": "value_error.not_found", - "incident_severity": incident_severity_in.name, - } - ]) + raise ValidationError( + [ + { + "msg": "Incident severity not found.", + "loc": ("incident_severity",), + "type": "value_error.not_found", + "incident_severity": incident_severity_in.name, + } + ] + ) return incident_severity @@ -87,14 +89,12 @@ def get_by_name_or_default( *, db_session, project_id: int, incident_severity_in=IncidentSeverityRead ) -> IncidentSeverity: """Returns an incident severity based on a name or the default if not specified.""" - if incident_severity_in: - if incident_severity_in.name: - return get_by_name_or_raise( - db_session=db_session, - project_id=project_id, - incident_severity_in=incident_severity_in, - ) - + if incident_severity_in and incident_severity_in.name: + incident_severity = get_by_name( + db_session=db_session, project_id=project_id, name=incident_severity_in.name + ) + if incident_severity: + return incident_severity return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/incident/type/service.py b/src/dispatch/incident/type/service.py index b27bc46f6a00..83f7631146c1 100644 --- a/src/dispatch/incident/type/service.py +++ b/src/dispatch/incident/type/service.py @@ -41,7 +41,7 @@ def get_default_or_raise(*, db_session, project_id: int) -> IncidentType: "input": None, "ctx": {"error": ValueError("No default incident type defined.")}, } - ] + ], ) return incident_type @@ -74,7 +74,7 @@ def get_by_name_or_raise( "input": incident_type_in.name, "ctx": {"error": ValueError("Incident type not found.")}, } - ] + ], ) return incident_type @@ -84,11 +84,12 @@ def get_by_name_or_default( *, db_session, project_id: int, incident_type_in=IncidentTypeRead ) -> IncidentType: """Returns a incident_type based on a name or the default if not specified.""" - if incident_type_in: - if incident_type_in.name: - return get_by_name_or_raise( - db_session=db_session, project_id=project_id, incident_type_in=incident_type_in - ) + if incident_type_in and incident_type_in.name: + incident_type = get_by_name( + db_session=db_session, project_id=project_id, name=incident_type_in.name + ) + if incident_type: + return incident_type return get_default_or_raise(db_session=db_session, project_id=project_id) diff --git a/src/dispatch/organization/service.py b/src/dispatch/organization/service.py index ba4b7679b026..a34771aeb783 100644 --- a/src/dispatch/organization/service.py +++ b/src/dispatch/organization/service.py @@ -1,4 +1,3 @@ - from pydantic import ValidationError from sqlalchemy.sql.expression import true @@ -25,13 +24,15 @@ def get_default_or_raise(*, db_session) -> Organization: organization = get_default(db_session=db_session) if not organization: - raise ValidationError([ - { - "loc": ("organization",), - "msg": "No default organization defined.", - "type": "value_error", - } - ]) + raise ValidationError( + [ + { + "loc": ("organization",), + "msg": "No default organization defined.", + "type": "value_error", + } + ] + ) return organization @@ -85,10 +86,11 @@ def get_by_slug_or_raise(*, db_session, organization_in: OrganizationRead) -> Or def get_by_name_or_default(*, db_session, organization_in: OrganizationRead) -> Organization: """Returns a organization based on a name or the default if not specified.""" - if organization_in.name: - return get_by_name_or_raise(db_session=db_session, organization_in=organization_in) - else: - return get_default_or_raise(db_session=db_session) + if organization_in and organization_in.name: + organization = get_by_name(db_session=db_session, name=organization_in.name) + if organization: + return organization + return get_default_or_raise(db_session=db_session) def get_all(*, db_session) -> list[Organization | None]: diff --git a/src/dispatch/project/service.py b/src/dispatch/project/service.py index d5c505d711f8..962325423c03 100644 --- a/src/dispatch/project/service.py +++ b/src/dispatch/project/service.py @@ -1,4 +1,3 @@ - from pydantic import ValidationError from sqlalchemy.orm import Session from sqlalchemy.sql.expression import true @@ -22,13 +21,15 @@ def get_default_or_raise(*, db_session: Session) -> Project: project = get_default(db_session=db_session) if not project: - raise ValidationError([ - { - "loc": ("project",), - "msg": "No default project defined.", - "type": "value_error", - } - ]) + raise ValidationError( + [ + { + "loc": ("project",), + "msg": "No default project defined.", + "type": "value_error", + } + ] + ) return project @@ -57,9 +58,10 @@ def get_by_name_or_raise(*, db_session: Session, project_in: ProjectRead) -> Pro def get_by_name_or_default(*, db_session, project_in: ProjectRead) -> Project: """Returns a project based on a name or the default if not specified.""" - if project_in: - if project_in.name: - return get_by_name_or_raise(db_session=db_session, project_in=project_in) + if project_in and project_in.name: + project = get_by_name(db_session=db_session, name=project_in.name) + if project: + return project return get_default_or_raise(db_session=db_session) diff --git a/tests/case_type/test_case_type_service.py b/tests/case_type/test_case_type_service.py index b07c6932e18e..331b6cfe3ecc 100644 --- a/tests/case_type/test_case_type_service.py +++ b/tests/case_type/test_case_type_service.py @@ -1,6 +1,7 @@ import datetime from datetime import timezone + def test_get(session, case_type): from dispatch.case.type.service import get @@ -113,3 +114,31 @@ def test_delete(session, case_type): delete(db_session=session, case_type_id=case_type.id) assert not get(db_session=session, case_type_id=case_type.id) + + +def test_get_by_name_or_default__name(session, case_type): + from dispatch.case.type.models import CaseTypeRead + from dispatch.case.type.service import get_by_name_or_default + + case_type_in = CaseTypeRead.from_orm(case_type) + result = get_by_name_or_default( + db_session=session, project_id=case_type.project.id, case_type_in=case_type_in + ) + assert result.id == case_type.id + + +def test_get_by_name_or_default__default(session, case_type): + from dispatch.case.type.models import CaseTypeRead + from dispatch.case.type.service import get_by_name_or_default + + # Ensure only one default case type + for ct in session.query(type(case_type)).all(): + ct.default = False + case_type.default = True + session.commit() + # Pass a CaseTypeRead with a non-existent name and dummy id > 0 + case_type_in = CaseTypeRead(id=99999, name="nonexistent", project=case_type.project) + result = get_by_name_or_default( + db_session=session, project_id=case_type.project.id, case_type_in=case_type_in + ) + assert result.id == case_type.id diff --git a/tests/incident_priority/test_incident_priority_service.py b/tests/incident_priority/test_incident_priority_service.py index 0535cd5caa00..e55562e1ad42 100644 --- a/tests/incident_priority/test_incident_priority_service.py +++ b/tests/incident_priority/test_incident_priority_service.py @@ -59,3 +59,37 @@ def test_delete(session, incident_priority): delete(db_session=session, incident_priority_id=incident_priority.id) assert not get(db_session=session, incident_priority_id=incident_priority.id) + + +def test_get_by_name_or_default__name(session, incident_priority): + from dispatch.incident.priority.models import IncidentPriorityRead + from dispatch.incident.priority.service import get_by_name_or_default + + incident_priority_in = IncidentPriorityRead.from_orm(incident_priority) + result = get_by_name_or_default( + db_session=session, + project_id=incident_priority.project.id, + incident_priority_in=incident_priority_in, + ) + assert result.id == incident_priority.id + + +def test_get_by_name_or_default__default(session, incident_priority): + from dispatch.incident.priority.models import IncidentPriorityRead + from dispatch.incident.priority.service import get_by_name_or_default + + # Ensure only one default incident priority + for ip in session.query(type(incident_priority)).all(): + ip.default = False + incident_priority.default = True + session.commit() + # Pass an IncidentPriorityRead with a non-existent name and dummy id > 0 + incident_priority_in = IncidentPriorityRead( + id=99999, name="nonexistent", project=incident_priority.project + ) + result = get_by_name_or_default( + db_session=session, + project_id=incident_priority.project.id, + incident_priority_in=incident_priority_in, + ) + assert result.id == incident_priority.id diff --git a/tests/incident_type/test_incident_type_service.py b/tests/incident_type/test_incident_type_service.py index 8a6806f22587..453cd01562ba 100644 --- a/tests/incident_type/test_incident_type_service.py +++ b/tests/incident_type/test_incident_type_service.py @@ -1,6 +1,7 @@ import datetime from datetime import timezone + def test_get(session, incident_type): from dispatch.incident.type.service import get @@ -112,3 +113,31 @@ def test_delete(session, incident_type): delete(db_session=session, incident_type_id=incident_type.id) assert not get(db_session=session, incident_type_id=incident_type.id) + + +def test_get_by_name_or_default__name(session, incident_type): + from dispatch.incident.type.models import IncidentTypeRead + from dispatch.incident.type.service import get_by_name_or_default + + incident_type_in = IncidentTypeRead.from_orm(incident_type) + result = get_by_name_or_default( + db_session=session, project_id=incident_type.project.id, incident_type_in=incident_type_in + ) + assert result.id == incident_type.id + + +def test_get_by_name_or_default__default(session, incident_type): + from dispatch.incident.type.models import IncidentTypeRead + from dispatch.incident.type.service import get_by_name_or_default + + # Ensure only one default incident type + for it in session.query(type(incident_type)).all(): + it.default = False + incident_type.default = True + session.commit() + # Pass an IncidentTypeRead with a non-existent name and dummy id > 0 + incident_type_in = IncidentTypeRead(id=99999, name="nonexistent", project=incident_type.project) + result = get_by_name_or_default( + db_session=session, project_id=incident_type.project.id, incident_type_in=incident_type_in + ) + assert result.id == incident_type.id diff --git a/tests/organization/test_organization_service.py b/tests/organization/test_organization_service.py index f784ee92c61d..1b7f8a3c7879 100644 --- a/tests/organization/test_organization_service.py +++ b/tests/organization/test_organization_service.py @@ -59,3 +59,27 @@ def test_delete(session, organization): delete(db_session=session, organization_id=organization.id) assert not get(db_session=session, organization_id=organization.id) + + +def test_get_by_name_or_default__name(session, organization): + from dispatch.organization.models import OrganizationRead + from dispatch.organization.service import get_by_name_or_default + + organization_in = OrganizationRead.from_orm(organization) + result = get_by_name_or_default(db_session=session, organization_in=organization_in) + assert result.id == organization.id + + +def test_get_by_name_or_default__default(session, organization): + from dispatch.organization.models import OrganizationRead + from dispatch.organization.service import get_by_name_or_default + + # Ensure only one default organization + for org in session.query(type(organization)).all(): + org.default = False + organization.default = True + session.commit() + # Pass an OrganizationRead with a non-existent name + organization_in = OrganizationRead(name="nonexistent") + result = get_by_name_or_default(db_session=session, organization_in=organization_in) + assert result.id == organization.id diff --git a/tests/project/test_project_service.py b/tests/project/test_project_service.py index f41600fa55af..e39f1128f2d1 100644 --- a/tests/project/test_project_service.py +++ b/tests/project/test_project_service.py @@ -21,7 +21,7 @@ def test_create(session, organization): id=organization.id, name=organization.name, slug=organization.slug, - description=organization.description + description=organization.description, ) # Generate a random integer ID for the project to avoid collisions @@ -87,3 +87,27 @@ def test_delete(session, project): delete(db_session=session, project_id=project.id) assert not get(db_session=session, project_id=project.id) + + +def test_get_by_name_or_default__name(session, project): + from dispatch.project.models import ProjectRead + from dispatch.project.service import get_by_name_or_default + + project_in = ProjectRead.from_orm(project) + result = get_by_name_or_default(db_session=session, project_in=project_in) + assert result.id == project.id + + +def test_get_by_name_or_default__default(session, project, organization): + from dispatch.project.models import ProjectRead + from dispatch.project.service import get_by_name_or_default + + # Ensure only one default project + for p in session.query(type(project)).all(): + p.default = False + project.default = True + session.commit() + # Pass a ProjectRead with a non-existent name + project_in = ProjectRead(name="nonexistent", organization=organization) + result = get_by_name_or_default(db_session=session, project_in=project_in) + assert result.id == project.id diff --git a/tests/signal/test_signal_service.py b/tests/signal/test_signal_service.py index f267046c0aee..c253a7f33c68 100644 --- a/tests/signal/test_signal_service.py +++ b/tests/signal/test_signal_service.py @@ -90,7 +90,7 @@ def test_create(session, project, case_priority, case_type, service, tag, entity ) with pytest.raises(ValidationError) as exc_info: create(db_session=session, signal_in=signal_in) - assert "Value error, Case priority not found:" in str(exc_info.value) + assert "No default case priority defined." in str(exc_info.value) def test_update(session, project, signal, case_priority, case_type, service, tag, entity_type): @@ -262,7 +262,7 @@ def test_update__add_filter( signal=signal, signal_in=signal_in, ) - assert "Value error, Case priority not found:" in str(exc_info.value) + assert "No default case priority defined." in str(exc_info.value) def test_update__delete_filter( @@ -345,7 +345,7 @@ def test_update__delete_filter( signal=signal, signal_in=signal_in, ) - assert "Value error, Case priority not found:" in str(exc_info.value) + assert "No default case priority defined." in str(exc_info.value) def test_delete(session, signal):