diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index caf48034f..c5e641f24 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, RssOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 79ab383a9..f2d2b46bf 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, RssOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.aio.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 6be34fb8a..77674abec 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 8aee57b46..4efe2dd6f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -51,6 +51,10 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_rss_get_compute_status_request, + build_rss_get_fts_status_request, + build_rss_get_site_status_request, + build_rss_get_storage_status_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2319,3 +2323,303 @@ async def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 14b5195d4..b4c06cc69 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -12,10 +12,16 @@ from ._models import ( # type: ignore + AllowedStatus, + BannedStatus, BodyAuthGetOidcToken, BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + ComputeElementStatus, + ComputeElementStatusAll, + FTSStatus, + FTSStatusAll, GroupInfo, HTTPValidationError, HeartbeatData, @@ -34,7 +40,14 @@ SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, + SiteStatus, + SiteStatusAll, SortSpec, + StorageElementStatus, + StorageElementStatusCheck, + StorageElementStatusRead, + StorageElementStatusRemove, + StorageElementStatusWrite, SummaryParams, SummaryParamsSearchItem, SupportInfo, @@ -59,10 +72,16 @@ from ._patch import patch_sdk as _patch_sdk __all__ = [ + "AllowedStatus", + "BannedStatus", "BodyAuthGetOidcToken", "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "ComputeElementStatus", + "ComputeElementStatusAll", + "FTSStatus", + "FTSStatusAll", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -81,7 +100,14 @@ "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", + "SiteStatus", + "SiteStatusAll", "SortSpec", + "StorageElementStatus", + "StorageElementStatusCheck", + "StorageElementStatusRead", + "StorageElementStatusRemove", + "StorageElementStatusWrite", "SummaryParams", "SummaryParamsSearchItem", "SupportInfo", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 888ec3b8a..730f15d6d 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -16,6 +16,70 @@ JSON = MutableMapping[str, Any] +class AllowedStatus(_serialization.Model): + """AllowedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar warnings: Warnings. + :vartype warnings: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "warnings": {"key": "warnings", "type": "str"}, + } + + def __init__(self, *, allowed: bool, warnings: Optional[str] = None, **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword warnings: Warnings. + :paramtype warnings: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.warnings = warnings + + +class BannedStatus(_serialization.Model): + """BannedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar reason: Reason. + :vartype reason: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "reason": {"key": "reason", "type": "str"}, + } + + def __init__(self, *, allowed: bool, reason: str = "Unknown", **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword reason: Reason. + :paramtype reason: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.reason = reason + + class BodyAuthGetOidcToken(_serialization.Model): """Body_auth_get_oidc_token. @@ -184,6 +248,66 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class ComputeElementStatus(_serialization.Model): + """ComputeElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.ComputeElementStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "ComputeElementStatusAll"}, + } + + def __init__(self, *, all: "_models.ComputeElementStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.ComputeElementStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class ComputeElementStatusAll(_serialization.Model): + """All.""" + + +class FTSStatus(_serialization.Model): + """FTSStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.FTSStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "FTSStatusAll"}, + } + + def __init__(self, *, all: "_models.FTSStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.FTSStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class FTSStatusAll(_serialization.Model): + """All.""" + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -1261,6 +1385,36 @@ def __init__( self.last_update_time = last_update_time +class SiteStatus(_serialization.Model): + """SiteStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.SiteStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "SiteStatusAll"}, + } + + def __init__(self, *, all: "_models.SiteStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.SiteStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class SiteStatusAll(_serialization.Model): + """All.""" + + class SortSpec(_serialization.Model): """SortSpec. @@ -1294,6 +1448,77 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class StorageElementStatus(_serialization.Model): + """StorageElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar read: Read. Required. + :vartype read: ~_generated.models.StorageElementStatusRead + :ivar write: Write. Required. + :vartype write: ~_generated.models.StorageElementStatusWrite + :ivar check: Check. Required. + :vartype check: ~_generated.models.StorageElementStatusCheck + :ivar remove: Remove. Required. + :vartype remove: ~_generated.models.StorageElementStatusRemove + """ + + _validation = { + "read": {"required": True}, + "write": {"required": True}, + "check": {"required": True}, + "remove": {"required": True}, + } + + _attribute_map = { + "read": {"key": "read", "type": "StorageElementStatusRead"}, + "write": {"key": "write", "type": "StorageElementStatusWrite"}, + "check": {"key": "check", "type": "StorageElementStatusCheck"}, + "remove": {"key": "remove", "type": "StorageElementStatusRemove"}, + } + + def __init__( + self, + *, + read: "_models.StorageElementStatusRead", + write: "_models.StorageElementStatusWrite", + check: "_models.StorageElementStatusCheck", + remove: "_models.StorageElementStatusRemove", + **kwargs: Any + ) -> None: + """ + :keyword read: Read. Required. + :paramtype read: ~_generated.models.StorageElementStatusRead + :keyword write: Write. Required. + :paramtype write: ~_generated.models.StorageElementStatusWrite + :keyword check: Check. Required. + :paramtype check: ~_generated.models.StorageElementStatusCheck + :keyword remove: Remove. Required. + :paramtype remove: ~_generated.models.StorageElementStatusRemove + """ + super().__init__(**kwargs) + self.read = read + self.write = write + self.check = check + self.remove = remove + + +class StorageElementStatusCheck(_serialization.Model): + """Check.""" + + +class StorageElementStatusRead(_serialization.Model): + """Read.""" + + +class StorageElementStatusRemove(_serialization.Model): + """Remove.""" + + +class StorageElementStatusWrite(_serialization.Model): + """Write.""" + + class SummaryParams(_serialization.Model): """SummaryParams. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 6be34fb8a..77674abec 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 11ffdcff7..69089682a 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -565,6 +565,118 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_rss_get_storage_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/storage" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_compute_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/compute" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_site_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/site" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_fts_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/fts" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2818,3 +2930,303 @@ def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any) -> l return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 35d5fa4e9..77f62e081 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations __all__ = [ + "AsyncCacheableSource", + "CacheableSource", "Config", "ConfigSource", "ConfigSourceUrl", @@ -14,6 +16,7 @@ "RegistryConfig", "RemoteGitConfigSource", "SerializableSet", + "Snapshot", "SupportInfo", "UserConfig", "is_running_in_async_context", @@ -31,9 +34,12 @@ UserConfig, ) from .sources import ( + AsyncCacheableSource, + CacheableSource, ConfigSource, ConfigSourceUrl, LocalGitConfigSource, RemoteGitConfigSource, + Snapshot, is_running_in_async_context, ) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py index f16fffa82..9853b79cb 100644 --- a/diracx-core/src/diracx/core/config/sources.py +++ b/diracx-core/src/diracx/core/config/sources.py @@ -9,10 +9,11 @@ import logging import os from abc import ABCMeta, abstractmethod +from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from tempfile import TemporaryDirectory -from typing import Annotated, Generic, TypeVar +from typing import Annotated, ClassVar, Generic, TypeVar from urllib.parse import urlparse, urlunparse import sh @@ -22,7 +23,7 @@ from diracx.core.exceptions import BadConfigurationVersionError from diracx.core.extensions import DiracEntryPoint, select_from_extension -from diracx.core.utils import TwoLevelCache +from diracx.core.utils import AsyncTwoLevelCache, TwoLevelCache from .schema import Config @@ -139,6 +140,79 @@ def clear_caches(self): self._content_cache.clear() +@dataclass(frozen=True) +class Snapshot(Generic[T]): + """Wraps a cached data payload with its cache metadata.""" + + data: T + hexsha: str + modified: datetime + + +class AsyncCacheableSource(Generic[T], metaclass=ABCMeta): + """Abstract base class for async sources that can be cached. + + Async equivalent of CacheableSource. Uses AsyncTwoLevelCache so populate + functions are native coroutines. + """ + + #: The database class this source reads from. Used by the application + #: factory to instantiate the source with the matching database instance. + db_class: ClassVar[type] + + def __init__(self): + self._revision_cache = AsyncTwoLevelCache( + soft_ttl=DEFAULT_CS_REV_CACHE_SOFT_TTL, + hard_ttl=DEFAULT_CS_REV_CACHE_HARD_TTL, + max_items=1, + ) + self._content_cache: Cache = LRUCache(maxsize=2) + + @abstractmethod + async def latest_revision(self) -> tuple[str, datetime]: + """Return (revision_str, modified) identifying the current revision.""" + + @abstractmethod + async def read_raw(self, hexsha: str, modified: datetime) -> T: + """Fetch and return the data for the given revision.""" + + async def _read_work(self) -> str: + hexsha, modified = await self.latest_revision() + if hexsha not in self._content_cache: + self._content_cache[hexsha] = await self.read_raw(hexsha, modified) + return hexsha + + async def read(self) -> T: + """Blocking read — awaits refresh on a hard cache miss.""" + hexsha = await self._revision_cache.get( + "latest_revision", self._read_work, blocking=True + ) + return self._content_cache[hexsha] + + async def read_non_blocking(self) -> T: + """Non-blocking read — raises NotReadyError on a hard cache miss.""" + hexsha = await self._revision_cache.get( + "latest_revision", self._read_work, blocking=False + ) + return self._content_cache[hexsha] + + async def clear_caches(self): + """Clear the caches.""" + await self._revision_cache.clear() + self._content_cache.clear() + + @classmethod + async def create(cls) -> T: + """Dependency injection stub. + + The application factory instantiates each concrete source and + overrides ``cls.create`` with the instance's ``read`` method, so this + should never actually be called. Each subclass's bound ``create`` + classmethod is a distinct dependency key. + """ + raise NotImplementedError(f"{cls.__name__} was not wired by the factory") + + class ConfigSource(CacheableSource[Config]): """Abstract class for the configuration source. diff --git a/diracx-core/src/diracx/core/extensions.py b/diracx-core/src/diracx/core/extensions.py index 4aeac7c15..7a261208f 100644 --- a/diracx-core/src/diracx/core/extensions.py +++ b/diracx-core/src/diracx/core/extensions.py @@ -23,6 +23,7 @@ class DiracEntryPoint(StrEnum): CORE = "diracx" ACCESS_POLICY = "diracx.access_policies" + CACHEABLE_SOURCES = "diracx.cacheable_sources" CLI = "diracx.cli" HIDDEN_CLI = "diracx.cli.hidden" OS_DB = "diracx.dbs.os" diff --git a/diracx-core/src/diracx/core/utils.py b/diracx-core/src/diracx/core/utils.py index 8949d07f3..c1ac4c7a3 100644 --- a/diracx-core/src/diracx/core/utils.py +++ b/diracx-core/src/diracx/core/utils.py @@ -2,6 +2,7 @@ __all__ = [ "EXPIRES_GRACE_SECONDS", + "AsyncTwoLevelCache", "TwoLevelCache", "batched_async", "dotenv_files_from_environment", @@ -11,6 +12,7 @@ "write_credentials", ] +import asyncio import fcntl import json import logging @@ -19,7 +21,7 @@ import stat import threading from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Coroutine from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime, timedelta, timezone from pathlib import Path @@ -293,6 +295,152 @@ def clear(self): self.locks.clear() +class AsyncTwoLevelCache: + """Async equivalent of TwoLevelCache, for use with async populate functions. + + Mirrors the two-TTL semantics of TwoLevelCache exactly: a soft TTL that + triggers a background refresh while still serving the stale value, and a + hard TTL beyond which a miss either awaits a fresh value (blocking=True) or + raises NotReadyError (blocking=False). + + The key difference from TwoLevelCache is that all coordination uses asyncio + primitives (asyncio.Lock, asyncio.Task) rather than a ThreadPoolExecutor, + so populate_func can be a native coroutine. + + Attributes: + soft_cache (TTLCache): A cache with a shorter TTL for quick access. + hard_cache (TTLCache): A cache with a longer TTL as a fallback. + tasks (dict): In-flight refresh Tasks keyed by cache key. + _lock (asyncio.Lock): Guards task creation to ensure single-flight behaviour. + + Args: + soft_ttl (int): Time-to-live in seconds for the soft cache. + hard_ttl (int): Time-to-live in seconds for the hard cache. + max_items (int): Maximum number of items in each cache tier. + + Example: + >>> cache = AsyncTwoLevelCache(soft_ttl=5, hard_ttl=3600) + >>> async def populate(): + ... return await some_db_query() + >>> value = await cache.get("key", populate) + + """ + + def __init__( + self, + soft_ttl: int, + hard_ttl: int, + *, + max_items: int = 1_000_000, + ): + """Initialize the AsyncTwoLevelCache with specified TTLs.""" + self.soft_cache: Cache = TTLCache(max_items, soft_ttl) + self.hard_cache: Cache = TTLCache(max_items, hard_ttl) + # One Task per key for single-flight refresh deduplication. + self.tasks: dict[str, asyncio.Task] = {} + # A single lock guards task creation across all keys. + # Per-key locks would be cleaner but require careful cleanup; + # contention here is minimal since task creation is very fast. + self._lock = asyncio.Lock() + + async def get( + self, + key: str, + populate_func: Callable[[], Coroutine[Any, Any, T]], + blocking: bool = True, + ) -> T: + """Retrieve a value from the cache, populating it if necessary. + + Checks the soft cache first. On a soft miss, kicks off a background + refresh and returns the stale hard-cache value if one exists. On a hard + miss, either awaits the refresh (blocking=True) or raises NotReadyError + (blocking=False). + + Args: + key (str): The cache key to retrieve or populate. + populate_func: An async callable (coroutine function) that returns + the value to cache. + blocking (bool): If True, wait for the populate_func to complete on + a hard miss. If False, raise NotReadyError instead. + + Returns: + The cached value associated with the key. + + """ + # Fast path: soft cache hit, no locking needed. + if key in self.soft_cache: + return self.soft_cache[key] + + async with self._lock: + # Re-check inside the lock in case another coroutine just populated it. + if key in self.soft_cache: + return self.soft_cache[key] + + # Ensure at most one refresh Task is in flight for this key. + if key not in self.tasks or self.tasks[key].done(): + self.tasks[key] = asyncio.create_task(self._work(key, populate_func)) + task = self.tasks[key] + + if key in self.hard_cache: + # Soft miss but hard hit: serve stale while the refresh runs. + # Pre-fill soft cache so the next request skips the lock entirely. + result = self.hard_cache[key] + self.soft_cache[key] = result + return result + + # Hard miss: no value in either cache yet. + if blocking: + # Await outside the lock so _work can acquire it to write results. + await task + return self.hard_cache[key] + + logger.debug( + "Cache key %r not ready yet, background population in progress", key + ) + raise NotReadyError(f"Cache key {key} is not ready yet.") + + async def _work( + self, key: str, populate_func: Callable[[], Coroutine[Any, Any, T]] + ) -> None: + """Await populate_func and write results into both cache tiers. + + Always removes the task entry so the next soft miss can schedule a fresh + refresh, regardless of whether this attempt succeeded or failed. + + Args: + key (str): The cache key to populate. + populate_func: Async callable that produces the value. + + """ + success = False + result = None + try: + result = await populate_func() + success = True + except Exception: + logger.error( + "Failed to populate cache key %r, will retry on next request", + key, + exc_info=True, + ) + raise + finally: + async with self._lock: + self.tasks.pop(key, None) + if success: + self.hard_cache[key] = result + self.soft_cache[key] = result + + async def clear(self): + """Cancel any in-flight refresh tasks and clear both cache tiers.""" + async with self._lock: + for task in self.tasks.values(): + task.cancel() + self.tasks.clear() + self.soft_cache.clear() + self.hard_cache.clear() + + async def batched_async( iterable: AsyncIterable[T], n: int, *, strict: bool = False ) -> AsyncIterable[tuple[T, ...]]: diff --git a/diracx-core/tests/test_utils.py b/diracx-core/tests/test_utils.py index b88483e39..42edc5346 100644 --- a/diracx-core/tests/test_utils.py +++ b/diracx-core/tests/test_utils.py @@ -11,6 +11,7 @@ from diracx.core.exceptions import NotReadyError from diracx.core.models import TokenResponse from diracx.core.utils import ( + AsyncTwoLevelCache, TwoLevelCache, dotenv_files_from_environment, read_credentials, @@ -299,3 +300,149 @@ def start_slow(): # Ensure background thread completes thread.join() + + +class TestAsyncTwoLevelCache: + """Tests for AsyncTwoLevelCache, mirroring TestTwoLevelCache.""" + + async def test_successful_population(self): + """Test that cache is populated successfully.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + return "test_value" + + result = await cache.get("key", populate) + assert result == "test_value" + assert call_count == 1 + + # Second call should use cached value + result = await cache.get("key", populate) + assert result == "test_value" + assert call_count == 1 + + async def test_failed_population_logs_and_allows_retry(self, caplog): + """Test that failed population logs error and allows retry on next request.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def failing_populate(): + nonlocal call_count + call_count += 1 + raise ValueError("Test error") + + # First call should fail and log the error + with pytest.raises(ValueError, match="Test error"): + await cache.get("key", failing_populate, blocking=True) + + assert call_count == 1 + assert "Failed to populate cache key 'key'" in caplog.text + assert "Test error" in caplog.text + + # Task should be removed, so next call should retry + with pytest.raises(ValueError, match="Test error"): + await cache.get("key", failing_populate, blocking=True) + + assert call_count == 2 # Should have retried + + async def test_failed_population_then_success(self): + """Test that after a failure, subsequent successful call works.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + should_fail = True + + async def populate(): + if should_fail: + raise ValueError("Test error") + return "success_value" + + # First call fails + with pytest.raises(ValueError): + await cache.get("key", populate, blocking=True) + + # Second call succeeds + should_fail = False + result = await cache.get("key", populate, blocking=True) + assert result == "success_value" + + async def test_none_return_value_is_cached(self): + """Test that None return values are properly cached.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate_none(): + nonlocal call_count + call_count += 1 + return None + + result = await cache.get("key", populate_none) + assert result is None + assert call_count == 1 + + # Second call should use cached None value + result = await cache.get("key", populate_none) + assert result is None + assert call_count == 1 # Should not have called populate again + + async def test_non_blocking_raises_not_ready(self): + """Test that non-blocking mode raises NotReadyError when cache is empty.""" + import asyncio + + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + started = asyncio.Event() + release = asyncio.Event() + + async def slow_populate(): + started.set() + await release.wait() + return "value" + + # Start population in the background + task = asyncio.create_task(cache.get("key", slow_populate, blocking=True)) + + # Wait until slow_populate has started to avoid race conditions + await asyncio.wait_for(started.wait(), timeout=1.0) + + # Non-blocking call should raise NotReadyError since cache isn't populated yet + with pytest.raises(NotReadyError, match="not ready yet"): + await cache.get("key", slow_populate, blocking=False) + + # Let the in-flight population finish + release.set() + assert await task == "value" + + async def test_single_flight_deduplication(self): + """Test that concurrent gets for the same key only populate once.""" + import asyncio + + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.05) + return "value" + + results = await asyncio.gather( + *(cache.get("key", populate, blocking=True) for _ in range(5)) + ) + assert results == ["value"] * 5 + assert call_count == 1 + + async def test_clear(self): + """Test that clear empties both cache tiers.""" + cache = AsyncTwoLevelCache(soft_ttl=10, hard_ttl=60) + call_count = 0 + + async def populate(): + nonlocal call_count + call_count += 1 + return "value" + + assert await cache.get("key", populate) == "value" + await cache.clear() + assert await cache.get("key", populate) == "value" + assert call_count == 2 diff --git a/diracx-db/src/diracx/db/sql/rss/db.py b/diracx-db/src/diracx/db/sql/rss/db.py index a891995f6..fc01ab96b 100644 --- a/diracx-db/src/diracx/db/sql/rss/db.py +++ b/diracx-db/src/diracx/db/sql/rss/db.py @@ -1,9 +1,9 @@ from __future__ import annotations -from sqlalchemy import select -from sqlalchemy.engine import Row +from datetime import datetime, timezone -from diracx.core.exceptions import ResourceNotFoundError +from sqlalchemy import func, insert, select +from sqlalchemy.engine import Row from ..utils import BaseSQLDB from .schema import ( @@ -18,37 +18,159 @@ class ResourceStatusDB(BaseSQLDB): metadata = RSSBase.metadata - async def get_site_status(self, name: str, vo: str = "all") -> tuple[str, str]: - stmt = select(SiteStatus.status, SiteStatus.reason).where( - SiteStatus.name == name, - SiteStatus.status_type == "all", - SiteStatus.vo == vo, - ) - result = await self.conn.execute(stmt) - row = result.one_or_none() - if not row: - raise ResourceNotFoundError(name) + async def get_site_statuses(self) -> list[tuple[str, str, str, str]]: + """Return all site statuses across all VOs. + + Returns: + List of (name, status, reason, vo) tuples. - return row.Status, row.Reason + """ + stmt = select( + SiteStatus.name, + SiteStatus.status, + SiteStatus.reason, + SiteStatus.vo, + ).where(SiteStatus.status_type == "all") + result = await self.conn.execute(stmt) + return [(row.Name, row.Status, row.Reason, row.VO) for row in result.all()] - async def get_resource_status( + async def get_resource_statuses( self, - name: str, status_types: list[str] | None = None, - vo: str = "all", - ) -> dict[str, Row]: + ) -> dict[str, dict[str, Row]]: + """Return resource statuses for the given status types across all VOs. + + Args: + status_types: Status type filter (e.g. ["ReadAccess", "WriteAccess"]). + Defaults to ["all"]. + + Returns: + Nested dict keyed by resource name then status type. + + """ if not status_types: status_types = ["all"] stmt = select( - ResourceStatus.status, ResourceStatus.reason, ResourceStatus.status_type + ResourceStatus.name, + ResourceStatus.status, + ResourceStatus.reason, + ResourceStatus.status_type, + ResourceStatus.vo, ).where( - ResourceStatus.name == name, ResourceStatus.status_type.in_(status_types), - ResourceStatus.vo == vo, ) result = await self.conn.execute(stmt) - rows = result.all() - if not rows: - raise ResourceNotFoundError(name) - return {row.StatusType: row for row in rows} + statuses: dict[str, dict[str, Row]] = {} + for row in result.all(): + if row.Name not in statuses: + statuses[row.Name] = {} + statuses[row.Name][row.StatusType] = row + return statuses + + async def get_resource_status_date( + self, + status_types: list[str] | None = None, + ) -> tuple[datetime | None, int]: + """Return the most recent DateEffective and row count for the given status types. + + Args: + status_types: Status type filter. Defaults to ["all"]. + + Returns: + (max_date_effective, row_count) across all VOs. The date is None + when the table contains no matching rows. + + """ + if not status_types: + status_types = ["all"] + stmt = select( + func.max(ResourceStatus.date_effective), + func.count(), + ).where(ResourceStatus.status_type.in_(status_types)) + result = await self.conn.execute(stmt) + max_date, count = result.one() + return max_date, count + + async def get_site_status_date(self) -> tuple[datetime | None, int]: + """Return the most recent DateEffective and row count from the SiteStatus table. + + Returns: + (max_date_effective, row_count) across all VOs. The date is None + when the table contains no matching rows. + + """ + stmt = select( + func.max(SiteStatus.date_effective), + func.count(), + ).where(SiteStatus.status_type == "all") + result = await self.conn.execute(stmt) + max_date, count = result.one() + return max_date, count + + async def insert_resource_status( + self, + name: str, + status: str, + status_type: str, + vo: str, + reason: str = "", + date_effective: datetime | None = None, + last_check_time: datetime | None = None, + ) -> None: + """Insert a single ResourceStatus row. + + Args: + name: Resource name. + status: Status value. + status_type: One of "all", "ReadAccess", "WriteAccess", etc. + vo: Virtual organisation (e.g. "lhcb", "all"). + reason: Human-readable reason string. + date_effective: Timestamp when the status became effective. + Defaults to now. + last_check_time: Timestamp of last check. Defaults to now. + + """ + now = datetime.now(timezone.utc) + stmt = insert(ResourceStatus).values( + Name=name, + Status=status, + StatusType=status_type, + VO=vo, + Reason=reason, + DateEffective=date_effective or now, + LastCheckTime=last_check_time or now, + ) + await self.conn.execute(stmt) + + async def insert_site_status( + self, + name: str, + status: str, + vo: str, + reason: str = "", + date_effective: datetime | None = None, + last_check_time: datetime | None = None, + ) -> None: + """Insert a single SiteStatus row. + + Args: + name: Site name (e.g. "LCG.CERN.cern"). + status: Status value (e.g. "Active", "Banned"). + vo: Virtual organisation. + reason: Human-readable reason string. + date_effective: Defaults to now. + last_check_time: Defaults to now. + + """ + now = datetime.now(timezone.utc) + stmt = insert(SiteStatus).values( + Name=name, + Status=status, + StatusType="all", + VO=vo, + Reason=reason, + DateEffective=date_effective or now, + LastCheckTime=last_check_time or now, + ) + await self.conn.execute(stmt) diff --git a/diracx-db/tests/rss/test_rss_db.py b/diracx-db/tests/rss/test_rss_db.py index 2956801ee..3457163c5 100644 --- a/diracx-db/tests/rss/test_rss_db.py +++ b/diracx-db/tests/rss/test_rss_db.py @@ -5,7 +5,6 @@ import pytest from sqlalchemy import insert -from diracx.core.exceptions import ResourceNotFoundError from diracx.db.sql.rss.db import ResourceStatusDB _NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) @@ -41,14 +40,13 @@ async def test_site_status(rss_db: ResourceStatusDB): # Test with the test Site (should be found) async with rss_db as db: - status, reason = await db.get_site_status("TestSite") + rows = await db.get_site_statuses() + assert rows + name, status, reason, vo = rows[0] + assert name == "TestSite" assert status == "Active" assert reason == "All good" - - # Test with an unknow Site (should not be found) - with pytest.raises(ResourceNotFoundError): - async with rss_db as db: - await db.get_site_status("Unknown") + assert vo == "all" async def test_resource_status(rss_db: ResourceStatusDB): @@ -102,34 +100,71 @@ async def test_resource_status(rss_db: ResourceStatusDB): # Test with the test Compute Element (should be found) async with rss_db as db: - result = await db.get_resource_status("TestCompute") + result = await db.get_resource_statuses() + assert "TestCompute" in result + result = result["TestCompute"] assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" + assert result["all"].VO == "all" # Test with the test FTS (should be found) async with rss_db as db: - result = await db.get_resource_status("TestFTS") + result = await db.get_resource_statuses() + assert "TestFTS" in result + result = result["TestFTS"] assert "all" in result assert result["all"].Status == "Active" assert result["all"].Reason == "All good" + assert result["all"].VO == "all" # Test with the test Storage Element (should be found) async with rss_db as db: - result = await db.get_resource_status( - "TestStorage", ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + result = await db.get_resource_statuses( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] ) - assert set(result.keys()) == { + assert set(result["TestStorage"].keys()) == { "ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess", } - for row in result.values(): + for row in result["TestStorage"].values(): assert row.Status == "Active" assert row.Reason == "All good" - # Test with an unknow Resource (should not be found) - with pytest.raises(ResourceNotFoundError): - async with rss_db as db: - await db.get_resource_status("Unknown") + # The date queries should return the latest date and the row count + async with rss_db as db: + max_date, count = await db.get_resource_status_date() + assert max_date == _NOW + assert count == 2 # TestCompute + TestFTS "all" rows + + async with rss_db as db: + max_date, count = await db.get_resource_status_date( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + ) + assert max_date == _NOW + assert count == 4 # TestStorage access rows + + +async def test_empty_tables(rss_db: ResourceStatusDB): + """Empty tables yield empty results rather than errors.""" + async with rss_db as db: + assert await db.get_site_statuses() == [] + assert await db.get_resource_statuses() == {} + assert await db.get_resource_status_date() == (None, 0) + assert await db.get_site_status_date() == (None, 0) + + +async def test_site_status_date(rss_db: ResourceStatusDB): + async with rss_db as db: + await db.insert_site_status( + name="LCG.CERN.cern", + status="Active", + vo="lhcb", + reason="All good", + date_effective=_NOW, + ) + max_date, count = await db.get_site_status_date() + assert max_date == _NOW + assert count == 1 diff --git a/diracx-logic/pyproject.toml b/diracx-logic/pyproject.toml index 889fe9a55..0ddfb5949 100644 --- a/diracx-logic/pyproject.toml +++ b/diracx-logic/pyproject.toml @@ -23,6 +23,12 @@ dependencies = [ ] dynamic = ["version"] +[project.entry-points."diracx.cacheable_sources"] +storage_status = "diracx.logic.rss.source:StorageElementStatusSource" +compute_status = "diracx.logic.rss.source:ComputeElementStatusSource" +fts_status = "diracx.logic.rss.source:FTSStatusSource" +site_status = "diracx.logic.rss.source:SiteStatusSource" + [project.optional-dependencies] testing = ["diracx-testing", "freezegun"] types = [ diff --git a/diracx-logic/src/diracx/logic/rss/query.py b/diracx-logic/src/diracx/logic/rss/query.py index 7ce83cbce..cd6a87d46 100644 --- a/diracx-logic/src/diracx/logic/rss/query.py +++ b/diracx-logic/src/diracx/logic/rss/query.py @@ -35,36 +35,81 @@ def map_status(db_status: str, reason: str | None = None) -> ResourceStatus: ) -async def get_site_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> SiteStatusModel: - status, reason = await resource_status_db.get_site_status(name, vo) - return SiteStatusModel(all=map_status(status, reason)) +async def get_site_statuses( + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, SiteStatusModel]]: + """Fetch all site statuses across all VOs. + The returned models carry the vo field so the router can filter to the + caller's VO from the cached all-VO snapshot. + """ + rows = await resource_status_db.get_site_statuses() -async def get_compute_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> ComputeElementStatus: - rows = await resource_status_db.get_resource_status(name, ["all"], vo) - return ComputeElementStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) + result: dict[str, dict[str, SiteStatusModel]] = {} + for name, status, reason, vo in rows: + vo = vo or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = SiteStatusModel(all=map_status(status, reason)) -async def get_fts_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> FTSStatus: - rows = await resource_status_db.get_resource_status(name, ["all"], vo) - return FTSStatus(all=map_status(rows["all"].Status, rows["all"].Reason)) + return result -async def get_storage_status( - resource_status_db: ResourceStatusDB, name: str, vo: str -) -> StorageElementStatus: - rows = await resource_status_db.get_resource_status( - name, ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"], vo - ) - return StorageElementStatus( - read=map_status(rows["ReadAccess"].Status, rows["ReadAccess"].Reason), - write=map_status(rows["WriteAccess"].Status, rows["WriteAccess"].Reason), - check=map_status(rows["CheckAccess"].Status, rows["CheckAccess"].Reason), - remove=map_status(rows["RemoveAccess"].Status, rows["RemoveAccess"].Reason), +async def get_compute_statuses( + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, ComputeElementStatus]]: + """Fetch all compute element statuses across all VOs.""" + all_rows = await resource_status_db.get_resource_statuses(["all"]) + + result: dict[str, dict[str, ComputeElementStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["all"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = ComputeElementStatus( + all=map_status(rows["all"].Status, rows["all"].Reason) + ) + + return result + + +async def get_fts_statuses( + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, FTSStatus]]: + """Fetch all FTS server statuses across all VOs.""" + all_rows = await resource_status_db.get_resource_statuses(["all"]) + + result: dict[str, dict[str, FTSStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["all"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = FTSStatus( + all=map_status(rows["all"].Status, rows["all"].Reason) + ) + + return result + + +async def get_storage_statuses( + resource_status_db: ResourceStatusDB, +) -> dict[str, dict[str, StorageElementStatus]]: + """Fetch all storage element statuses across all VOs.""" + all_rows = await resource_status_db.get_resource_statuses( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] ) + + result: dict[str, dict[str, StorageElementStatus]] = {} + for name, rows in all_rows.items(): + vo = rows["ReadAccess"].VO or "all" + if vo not in result: + result[vo] = {} + result[vo][name] = StorageElementStatus( + read=map_status(rows["ReadAccess"].Status, rows["ReadAccess"].Reason), + write=map_status(rows["WriteAccess"].Status, rows["WriteAccess"].Reason), + check=map_status(rows["CheckAccess"].Status, rows["CheckAccess"].Reason), + remove=map_status(rows["RemoveAccess"].Status, rows["RemoveAccess"].Reason), + ) + + return result diff --git a/diracx-logic/src/diracx/logic/rss/source.py b/diracx-logic/src/diracx/logic/rss/source.py new file mode 100644 index 000000000..61091eab2 --- /dev/null +++ b/diracx-logic/src/diracx/logic/rss/source.py @@ -0,0 +1,120 @@ +"""Resource Status System source classes. + +These classes live in diracx-logic so they can import from diracx-db without +violating the project's dependency flow: + + routers → logic → db → core +""" + +from __future__ import annotations + +import logging +from abc import abstractmethod +from datetime import datetime, timezone +from typing import ClassVar + +from diracx.core.config.sources import AsyncCacheableSource, Snapshot +from diracx.db.sql.rss.db import ResourceStatusDB + +from .query import ( + get_compute_statuses, + get_fts_statuses, + get_site_statuses, + get_storage_statuses, +) + +logger = logging.getLogger(__name__) + +#: Revision returned when the underlying table contains no rows. +EMPTY_REVISION = ("empty-0", datetime(1970, 1, 1, tzinfo=timezone.utc)) + + +def _make_revision(max_date: datetime | None, count: int) -> tuple[str, datetime]: + """Build a (revision, modified) pair from the latest date and row count. + + Including the row count in the revision means insertions and deletions + change the ETag even when they do not advance the latest DateEffective. + """ + if max_date is None: + return EMPTY_REVISION + return f"{max_date.isoformat()}-{count}", max_date + + +class ResourceStatusSource(AsyncCacheableSource[Snapshot]): + """Base caching source for Compute, Storage and FTS resource types. + + Subclasses declare the status types their data lives in and how to fetch + it from the database. + + One source instance per resource type covers all VOs. VO filtering is done + in the route after the snapshot is fetched from the cache. + """ + + db_class = ResourceStatusDB + + #: Status types holding this resource type's data, used both for the + #: revision query and the data fetch. + status_types: ClassVar[list[str]] + + def __init__(self, *, db: ResourceStatusDB) -> None: + super().__init__() + self._db = db + + async def latest_revision(self) -> tuple[str, datetime]: + async with self._db as db: + max_date, count = await db.get_resource_status_date(self.status_types) + return _make_revision(max_date, count) + + async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: + async with self._db as db: + data = await self._fetch(db) + return Snapshot(data=data, hexsha=hexsha, modified=modified) + + @abstractmethod + async def _fetch(self, db: ResourceStatusDB) -> dict: + """Fetch this resource type's statuses, keyed by VO then name.""" + + +class StorageElementStatusSource(ResourceStatusSource): + status_types = ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_storage_statuses(db) + + +class ComputeElementStatusSource(ResourceStatusSource): + status_types = ["all"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_compute_statuses(db) + + +class FTSStatusSource(ResourceStatusSource): + status_types = ["all"] + + async def _fetch(self, db: ResourceStatusDB) -> dict: + return await get_fts_statuses(db) + + +class SiteStatusSource(AsyncCacheableSource[Snapshot]): + """Caching source for Site statuses. + + Uses its own DB table (SiteStatus) and a dedicated date query, so it is a + direct subclass of AsyncCacheableSource rather than ResourceStatusSource. + """ + + db_class = ResourceStatusDB + + def __init__(self, *, db: ResourceStatusDB) -> None: + super().__init__() + self._db = db + + async def latest_revision(self) -> tuple[str, datetime]: + async with self._db as db: + max_date, count = await db.get_site_status_date() + return _make_revision(max_date, count) + + async def read_raw(self, hexsha: str, modified: datetime) -> Snapshot: + async with self._db as db: + data = await get_site_statuses(db) + return Snapshot(data=data, hexsha=hexsha, modified=modified) diff --git a/diracx-logic/tests/rss/test_rss.py b/diracx-logic/tests/rss/test_rss_query.py similarity index 100% rename from diracx-logic/tests/rss/test_rss.py rename to diracx-logic/tests/rss/test_rss_query.py diff --git a/diracx-logic/tests/rss/test_rss_source.py b/diracx-logic/tests/rss/test_rss_source.py new file mode 100644 index 000000000..503c364c2 --- /dev/null +++ b/diracx-logic/tests/rss/test_rss_source.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +from collections import namedtuple +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + StorageElementStatus, +) +from diracx.db.sql.rss.db import ResourceStatusDB +from diracx.logic.rss.source import ( + ComputeElementStatusSource, + FTSStatusSource, + SiteStatusSource, + StorageElementStatusSource, +) + +_MAX_DATE = datetime.fromisoformat("2023-01-01T00:00:00+00:00") + + +@pytest.fixture +def mock_resource_status_db(): + """Fixture to mock the ResourceStatusDB.""" + db = MagicMock(spec=ResourceStatusDB) + db.__aenter__ = AsyncMock(return_value=db) + db.__aexit__ = AsyncMock(return_value=None) + db.get_resource_status_date = AsyncMock(return_value=(_MAX_DATE, 4)) + db.get_site_status_date = AsyncMock(return_value=(_MAX_DATE, 2)) + return db + + +async def test_latest_revision(mock_resource_status_db): + """Test the latest_revision method of ResourceStatusSource.""" + source = ComputeElementStatusSource(db=mock_resource_status_db) + + # Call the method + revision, modified = await source.latest_revision() + + # Verify the revision is generated correctly + assert revision == f"{_MAX_DATE.isoformat()}-4" + assert modified == _MAX_DATE + + # Verify the database call queries this source's status types + mock_resource_status_db.get_resource_status_date.assert_awaited_once_with(["all"]) + + +async def test_latest_revision_storage_status_types(mock_resource_status_db): + """Storage revisions must track the access status types, not "all".""" + source = StorageElementStatusSource(db=mock_resource_status_db) + + await source.latest_revision() + + mock_resource_status_db.get_resource_status_date.assert_awaited_once_with( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + ) + + +async def test_latest_revision_empty(mock_resource_status_db): + """An empty table yields a stable sentinel revision instead of failing.""" + mock_resource_status_db.get_resource_status_date = AsyncMock(return_value=(None, 0)) + source = ComputeElementStatusSource(db=mock_resource_status_db) + + revision, modified = await source.latest_revision() + + assert revision == "empty-0" + assert modified == datetime(1970, 1, 1, tzinfo=timezone.utc) + + +async def test_latest_revision_site(mock_resource_status_db): + """Test the latest_revision method of SiteStatusSource.""" + source = SiteStatusSource(db=mock_resource_status_db) + + revision, modified = await source.latest_revision() + + assert revision == f"{_MAX_DATE.isoformat()}-2" + assert modified == _MAX_DATE + mock_resource_status_db.get_site_status_date.assert_awaited_once_with() + + +async def test_read_raw_site(mock_resource_status_db): + """Test the read_raw method for Site resource type.""" + # Mock the database data + mock_db_data = [("testSite", "Active", "", "test_vo")] + + # Patch the get_site_statuses method of the database to return the mock data + mock_resource_status_db.get_site_statuses = AsyncMock(return_value=mock_db_data) + + # Initialize the ResourceStatusSource with the mocked database + source = SiteStatusSource(db=mock_resource_status_db) + + # Call the read_raw method, which internally calls get_site_statuses from query.py + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result matches the expected output + expected_result = {"testSite": SiteStatus(all={"allowed": True, "warnings": None})} + for key, value in expected_result.items(): + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + # Verify that the database method was called correctly + mock_resource_status_db.get_site_statuses.assert_awaited_once() + + +async def test_read_raw_compute(mock_resource_status_db): + """Test the read_raw method for ComputeElement resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) + + mock_db_data = { + "TestCE": { + "all": ResourceStatus( + Name="TestCE", Status="Active", Reason="", VO="test_vo" + ) + } + } + mock_resource_status_db.get_resource_statuses = AsyncMock(return_value=mock_db_data) + + source = ComputeElementStatusSource(db=mock_resource_status_db) + + # Call the method + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = { + "TestCE": ComputeElementStatus(all={"allowed": True, "warnings": None}) + } + for key, value in expected_result.items(): + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with(["all"]) + + +async def test_read_raw_storage(mock_resource_status_db): + """Test the read_raw method for StorageElement resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) + + mock_db_data = { + "TestSE": { + "ReadAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "WriteAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "CheckAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + "RemoveAccess": ResourceStatus( + Name="TestSE", Status="Active", Reason=None, VO="test_vo" + ), + } + } + mock_resource_status_db.get_resource_statuses.return_value = mock_db_data + source = StorageElementStatusSource(db=mock_resource_status_db) + + # Call the method + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = { + "TestSE": StorageElementStatus( + read={"allowed": True, "warnings": None}, + write={"allowed": True, "warnings": None}, + check={"allowed": True, "warnings": None}, + remove={"allowed": True, "warnings": None}, + ) + } + for key, value in expected_result.items(): + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with( + ["ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"] + ) + + +async def test_read_raw_fts(mock_resource_status_db): + """Test the read_raw method for FTS resource type.""" + ResourceStatus = namedtuple("ResourceStatus", ["Name", "Status", "Reason", "VO"]) + + mock_db_data = { + "FTS": { + "all": ResourceStatus( + Name="FTS", Status="Active", Reason=None, VO="test_vo" + ), + } + } + mock_resource_status_db.get_resource_statuses.return_value = mock_db_data + + source = FTSStatusSource(db=mock_resource_status_db) + + # Call the method + result = await source.read_raw("test_revision", datetime.now(tz=timezone.utc)) + + # Verify the result + expected_result = {"FTS": FTSStatus(all={"allowed": True, "warnings": None})} + for key, value in expected_result.items(): + assert key in result.data["test_vo"] + assert value.model_dump() == result.data["test_vo"][key].model_dump() + mock_resource_status_db.get_resource_statuses.assert_awaited_once_with(["all"]) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index d798b1696..26e49317f 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -47,10 +47,12 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +rss = "diracx.routers.rss:router" [project.entry-points."diracx.access_policies"] wms = "diracx.routers.jobs.access_policies:WMSAccessPolicy" sandbox = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +rss = "diracx.routers.rss.access_policies:RSSAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] @@ -87,5 +89,4 @@ markers = [ "enabled_dependencies: List of dependencies which should be available to the FastAPI test client", ] - asyncio_default_fixture_loop_scope = "function" diff --git a/diracx-routers/src/diracx/routers/configuration.py b/diracx-routers/src/diracx/routers/configuration.py index 63d4ea199..16f5a6536 100644 --- a/diracx-routers/src/diracx/routers/configuration.py +++ b/diracx-routers/src/diracx/routers/configuration.py @@ -3,13 +3,10 @@ __all__ = ["router"] import logging -from datetime import datetime, timezone -from http import HTTPStatus from typing import Annotated from fastapi import ( Header, - HTTPException, Response, ) @@ -17,11 +14,10 @@ from .access_policies import open_access from .fastapi_classes import DiracxRouter +from .utils.http_cache import apply_cache_headers logger = logging.getLogger(__name__) -LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" - router = DiracxRouter() @@ -42,31 +38,12 @@ async def serve_config( return 304: this is to avoid flip/flopping """ # await check_permissions() - headers = { - "ETag": config._hexsha, - "Last-Modified": config._modified.strftime(LAST_MODIFIED_FORMAT), - } - - if if_none_match == config._hexsha: - raise HTTPException(status_code=HTTPStatus.NOT_MODIFIED, headers=headers) - - # This is to prevent flip/flopping in case - # a server gets out of sync with disk - if if_modified_since: - try: - not_before = datetime.strptime( - if_modified_since, LAST_MODIFIED_FORMAT - ).astimezone(timezone.utc) - except ValueError: - logger.debug( - "Failed to parse If-Modified-Since header: %s", if_modified_since - ) - else: - if not_before > config._modified: - raise HTTPException( - status_code=HTTPStatus.NOT_MODIFIED, headers=headers - ) - - response.headers.update(headers) + apply_cache_headers( + response, + etag=config._hexsha, + modified=config._modified, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ) return config diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py index 4c26f4c58..b33490be3 100644 --- a/diracx-routers/src/diracx/routers/factory.py +++ b/diracx-routers/src/diracx/routers/factory.py @@ -185,6 +185,7 @@ def create_app_inner( fail_startup = True # Add the SQL DBs to the application available_sql_db_classes: set[type[BaseSQLDB]] = set() + sql_db_instances: dict[type[BaseSQLDB], BaseSQLDB] = {} for db_name, db_url in database_urls.items(): try: @@ -199,6 +200,7 @@ def create_app_inner( for sql_db_class in sql_db_classes: assert sql_db_class.transaction not in app.dependency_overrides available_sql_db_classes.add(sql_db_class) + sql_db_instances[sql_db_class] = sql_db app.dependency_overrides[sql_db_class.transaction] = partial( db_transaction, sql_db @@ -215,6 +217,28 @@ def create_app_inner( if fail_startup: raise Exception("No SQL database could be initialized, aborting") + # Instantiate the cacheable sources and override their create methods, + # mirroring the ConfigSource wiring above. A single instance is used for + # each source so that its caches persist across requests. + wired_source_names = set() + for entry_point in select_from_extension(group=DiracEntryPoint.CACHEABLE_SOURCES): + # The first entry point for a given name is the highest priority one + if entry_point.name in wired_source_names: + continue + wired_source_names.add(entry_point.name) + source_cls = entry_point.load() + source_db = sql_db_instances.get(source_cls.db_class) + if source_db is None: + logger.warning( + "Cannot wire cacheable source %s: %s is not available", + entry_point.name, + source_cls.db_class.__name__, + ) + continue + source = source_cls(db=source_db) + assert source_cls.create not in app.dependency_overrides + app.dependency_overrides[source_cls.create] = source.read + # Add the OpenSearch DBs to the application available_os_db_classes: set[type[BaseOSDB]] = set() for db_name, connection_kwargs in os_database_conn_kwargs.items(): diff --git a/diracx-routers/src/diracx/routers/rss/__init__.py b/diracx-routers/src/diracx/routers/rss/__init__.py new file mode 100644 index 000000000..1f8b8659e --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +__all__ = ["RSSAccessPolicy", "router"] + +from ..fastapi_classes import DiracxRouter +from .access_policies import RSSAccessPolicy +from .rss import router as rss_router + +router = DiracxRouter() +router.include_router(rss_router) diff --git a/diracx-routers/src/diracx/routers/rss/access_policies.py b/diracx-routers/src/diracx/routers/rss/access_policies.py new file mode 100644 index 000000000..e30e52b32 --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/access_policies.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Annotated + +from fastapi import Depends + +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class RSSAccessPolicy(BaseAccessPolicy): + """Any authenticated user can access.""" + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + ): + # Authentication is already guaranteed by verify_dirac_access_token; + # any authenticated user may read resource statuses. VO scoping is + # applied in the routes themselves. + return + + +CheckRSSPolicyCallable = Annotated[Callable, Depends(RSSAccessPolicy.check)] diff --git a/diracx-routers/src/diracx/routers/rss/rss.py b/diracx-routers/src/diracx/routers/rss/rss.py new file mode 100644 index 000000000..ee7adfff5 --- /dev/null +++ b/diracx-routers/src/diracx/routers/rss/rss.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import logging +from typing import Annotated, Any + +from fastapi import Depends, Header, Response + +from diracx.core.config.sources import Snapshot +from diracx.core.models.rss import ( + ComputeElementStatus, + FTSStatus, + SiteStatus, + StorageElementStatus, +) +from diracx.logic.rss.source import ( + ComputeElementStatusSource, + FTSStatusSource, + SiteStatusSource, + StorageElementStatusSource, +) +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..fastapi_classes import DiracxRouter +from ..utils.http_cache import apply_cache_headers +from .access_policies import CheckRSSPolicyCallable + +logger = logging.getLogger(__name__) + +router = DiracxRouter() + + +def _vo_view( + snapshot: Snapshot, + vo: str, + response: Response, + if_none_match: str | None, + if_modified_since: str | None, +) -> dict[str, Any]: + """Apply cache headers and return the caller's VO view of a snapshot. + + The snapshot covers all VOs so it can be cached once; the response is the + "all" entries overlaid with the caller's VO-specific entries. The ETag is + suffixed with the VO (and Vary: Authorization set) since the same URL + serves different content per VO. + """ + apply_cache_headers( + response, + etag=f"{snapshot.hexsha}-{vo}", + modified=snapshot.modified, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + vary="Authorization", + ) + return {**snapshot.data.get("all", {}), **snapshot.data.get(vo, {})} + + +@router.get("/storage") +async def get_storage_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(StorageElementStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, StorageElementStatus]: + """Get the latest status of storage elements, scoped to the caller's VO.""" + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) + + +@router.get("/compute") +async def get_compute_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(ComputeElementStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, ComputeElementStatus]: + """Get the latest status of compute elements, scoped to the caller's VO.""" + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) + + +@router.get("/site") +async def get_site_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(SiteStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, SiteStatus]: + """Get the latest status of sites, scoped to the caller's VO.""" + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) + + +@router.get("/fts") +async def get_fts_status( + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + check_permissions: CheckRSSPolicyCallable, + snapshot: Annotated[Snapshot, Depends(FTSStatusSource.create)], + if_none_match: Annotated[str | None, Header()] = None, + if_modified_since: Annotated[str | None, Header()] = None, +) -> dict[str, FTSStatus]: + """Get the latest status of FTS servers, scoped to the caller's VO.""" + return _vo_view(snapshot, user_info.vo, response, if_none_match, if_modified_since) diff --git a/diracx-routers/src/diracx/routers/utils/__init__.py b/diracx-routers/src/diracx/routers/utils/__init__.py index a8c5919dc..260a03638 100644 --- a/diracx-routers/src/diracx/routers/utils/__init__.py +++ b/diracx-routers/src/diracx/routers/utils/__init__.py @@ -1,5 +1,11 @@ from __future__ import annotations -__all__ = ["AuthorizedUserInfo", "verify_dirac_access_token"] +__all__ = [ + "LAST_MODIFIED_FORMAT", + "AuthorizedUserInfo", + "apply_cache_headers", + "verify_dirac_access_token", +] +from .http_cache import LAST_MODIFIED_FORMAT, apply_cache_headers from .users import AuthorizedUserInfo, verify_dirac_access_token diff --git a/diracx-routers/src/diracx/routers/utils/http_cache.py b/diracx-routers/src/diracx/routers/utils/http_cache.py new file mode 100644 index 000000000..7966d0cba --- /dev/null +++ b/diracx-routers/src/diracx/routers/utils/http_cache.py @@ -0,0 +1,72 @@ +"""Helpers for HTTP conditional-request caching (ETag / Last-Modified / 304).""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from http import HTTPStatus + +from fastapi import HTTPException, Response + +logger = logging.getLogger(__name__) + +LAST_MODIFIED_FORMAT = "%a, %d %b %Y %H:%M:%S GMT" + + +def apply_cache_headers( + response: Response, + *, + etag: str, + modified: datetime, + if_none_match: str | None, + if_modified_since: str | None, + vary: str | None = None, +) -> None: + """Set ETag / Last-Modified headers and raise 304 when appropriate. + + If If-None-Match matches the current ETag, return 304. + + If If-Modified-Since is given and is newer than the current Last-Modified, + return 304: this is to avoid flip/flopping in case a server gets out of + sync with the source of truth. + + Args: + response: The response whose headers should be updated. + etag: The current entity tag. + modified: The current modification time (timezone-aware). + if_none_match: Value of the If-None-Match request header, if any. + if_modified_since: Value of the If-Modified-Since request header, if any. + vary: Optional value for the Vary header, for responses whose content + depends on more than the URL (e.g. the caller's identity). + + Raises: + HTTPException(304): when the client's cached copy is still current. + + """ + headers = { + "ETag": etag, + "Last-Modified": modified.strftime(LAST_MODIFIED_FORMAT), + } + if vary is not None: + headers["Vary"] = vary + + if if_none_match == etag: + raise HTTPException(status_code=HTTPStatus.NOT_MODIFIED, headers=headers) + + if if_modified_since: + try: + # The If-Modified-Since header is always GMT (RFC 9110) + not_before = datetime.strptime( + if_modified_since, LAST_MODIFIED_FORMAT + ).replace(tzinfo=timezone.utc) + except ValueError: + logger.debug( + "Failed to parse If-Modified-Since header: %s", if_modified_since + ) + else: + if not_before > modified: + raise HTTPException( + status_code=HTTPStatus.NOT_MODIFIED, headers=headers + ) + + response.headers.update(headers) diff --git a/diracx-routers/tests/test_rss.py b/diracx-routers/tests/test_rss.py new file mode 100644 index 000000000..64df9c5fa --- /dev/null +++ b/diracx-routers/tests/test_rss.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from http import HTTPStatus + +import pytest + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ResourceStatusDB", + "SiteStatusSource", + "FTSStatusSource", + "ComputeElementStatusSource", + "StorageElementStatusSource", + "RSSAccessPolicy", + "DevelopmentSettings", + ] +) + +ALL_ENDPOINTS = [ + "/api/rss/storage", + "/api/rss/compute", + "/api/rss/site", + "/api/rss/fts", +] + + +def _get_rss_db(client): + from diracx.db.sql.rss.db import ResourceStatusDB + + db_override = client.app.dependency_overrides[ResourceStatusDB.no_transaction] + # factory.py stores partial(db_no_transaction, db_instance); args[0] is the instance. + return db_override.args[0] + + +async def _clear_source_caches(client): + """Clear the singleton sources' caches. + + The sources live for the whole test session while each test gets a fresh + database, so any snapshot cached by a previous test must be dropped. + """ + from diracx.core.config.sources import AsyncCacheableSource + + for override in client.app.dependency_overrides.values(): + source = getattr(override, "__self__", None) + if isinstance(source, AsyncCacheableSource): + await source.clear_caches() + + +async def _prepare_rss(client): + """Reset the source caches and seed the database.""" + await _clear_source_caches(client) + + db = _get_rss_db(client) + now = datetime.now(tz=timezone.utc) + + async with db as conn: + for status_type in ("ReadAccess", "WriteAccess", "CheckAccess", "RemoveAccess"): + await conn.insert_resource_status( + name="SE-CERN", + status="Active", + status_type=status_type, + vo="lhcb", + reason="All good", + date_effective=now, + ) + # A storage element belonging to another VO, which the test user + # (vo=lhcb) must not see. + await conn.insert_resource_status( + name="SE-OTHER", + status="Active", + status_type=status_type, + vo="other_vo", + reason="All good", + date_effective=now, + ) + await conn.insert_resource_status( + name="CE-CERN", + status="Active", + status_type="all", + vo="lhcb", + reason="All good", + date_effective=now, + ) + await conn.insert_resource_status( + name="FTS-CERN", + status="Active", + status_type="all", + vo="lhcb", + reason="All good", + date_effective=now, + ) + await conn.insert_site_status( + name="LCG.CERN.cern", + status="Active", + vo="lhcb", + reason="All good", + date_effective=now, + ) + # A site visible to every VO. + await conn.insert_site_status( + name="LCG.Shared.ch", + status="Active", + vo="all", + reason="All good", + date_effective=now, + ) + + +@pytest.fixture +def normal_user_client(client_factory): + with client_factory.normal_user() as client: + # Run on the TestClient's portal so async primitives are bound to the + # same event loop that serves the requests. + client.portal.call(_prepare_rss, client) + yield client + + +@pytest.fixture +def empty_db_client(client_factory): + with client_factory.normal_user() as client: + client.portal.call(_clear_source_caches, client) + yield client + + +@pytest.fixture +def unauthenticated_client(client_factory): + with client_factory.unauthenticated() as client: + yield client + + +def test_unauthenticated(unauthenticated_client): + response = unauthenticated_client.get("/api/rss/storage") + assert response.status_code == HTTPStatus.UNAUTHORIZED + + +@pytest.mark.parametrize("endpoint", ALL_ENDPOINTS) +def test_get_resource_status(normal_user_client, endpoint): + r = normal_user_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.json() + assert r.json(), r.text + + last_modified = r.headers["Last-Modified"] + etag = r.headers["ETag"] + # The same URL serves different content per VO, so the ETag must identify + # the VO and caches must be told the response varies with the caller. + assert etag.endswith("-lhcb") + assert "Authorization" in r.headers["Vary"] + + # Matching ETag + matching Last-Modified → 304 + r = normal_user_client.get( + endpoint, + headers={"If-None-Match": etag, "If-Modified-Since": last_modified}, + ) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + assert not r.text + + # Wrong ETag only → 200 + r = normal_user_client.get( + endpoint, + headers={"If-None-Match": "wrongEtag"}, + ) + assert r.status_code == HTTPStatus.OK, r.json() + assert r.json(), r.text + + # Past ETag + past timestamp → 200 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "pastEtag", + "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", + }, + ) + assert r.status_code == HTTPStatus.OK, r.json() + assert r.json(), r.text + + # Wrong ETag + future timestamp → 304 (If-Modified-Since takes effect) + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "futureEtag", + "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", + }, + ) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + assert not r.text + + # Wrong ETag + invalid timestamp → 200 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": "futureEtag", + "If-Modified-Since": "wrong format", + }, + ) + assert r.status_code == HTTPStatus.OK, r.json() + assert r.json(), r.text + + # Correct ETag + past timestamp → 304 (ETag match takes priority) + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "Mon, 1 Apr 2000 00:42:42 GMT", + }, + ) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + assert not r.text + + # Correct ETag + future timestamp → 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "Mon, 1 Apr 9999 00:42:42 GMT", + }, + ) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + assert not r.text + + # Correct ETag + invalid timestamp → 304 + r = normal_user_client.get( + endpoint, + headers={ + "If-None-Match": etag, + "If-Modified-Since": "wrong format", + }, + ) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + assert not r.text + + +def test_vo_filtering(normal_user_client): + """Users only see "all" entries plus those of their own VO.""" + r = normal_user_client.get("/api/rss/storage") + assert r.status_code == HTTPStatus.OK, r.json() + assert set(r.json()) == {"SE-CERN"} # not SE-OTHER (vo=other_vo) + + r = normal_user_client.get("/api/rss/site") + assert r.status_code == HTTPStatus.OK, r.json() + assert set(r.json()) == {"LCG.CERN.cern", "LCG.Shared.ch"} + + +@pytest.mark.parametrize("endpoint", ALL_ENDPOINTS) +def test_empty_db(empty_db_client, endpoint): + """An empty database yields an empty result with valid cache headers.""" + r = empty_db_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + assert r.json() == {} + assert r.headers["ETag"] == "empty-0-lhcb" + + # A conditional request against the sentinel revision still works + r = empty_db_client.get(endpoint, headers={"If-None-Match": "empty-0-lhcb"}) + assert r.status_code == HTTPStatus.NOT_MODIFIED, r.text + + +def test_served_from_cache(normal_user_client, monkeypatch): + """Once populated, requests are served from the cache without DB access.""" + from diracx.db.sql.rss.db import ResourceStatusDB + + # Populate the cache for every endpoint + for endpoint in ALL_ENDPOINTS: + r = normal_user_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + + # Break every read path of the DB to prove the cache is used + async def _fail(*args, **kwargs): + raise AssertionError("The database should not be accessed") + + for method in ( + "get_site_statuses", + "get_resource_statuses", + "get_resource_status_date", + "get_site_status_date", + ): + monkeypatch.setattr(ResourceStatusDB, method, _fail) + + for endpoint in ALL_ENDPOINTS: + r = normal_user_client.get(endpoint) + assert r.status_code == HTTPStatus.OK, r.text + assert r.json(), r.text diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 570241e69..e700d5fa3 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -168,7 +168,7 @@ def __init__( test_sandbox_settings, test_dev_settings, ): - from diracx.core.config import ConfigSource + from diracx.core.config import AsyncCacheableSource, ConfigSource from diracx.core.extensions import select_from_extension from diracx.core.settings import ServiceSettingsBase from diracx.db.os.utils import BaseOSDB @@ -252,6 +252,7 @@ def enrich_tokens( BaseSQLDB, BaseOSDB, ConfigSource, + AsyncCacheableSource, BaseAccessPolicy, ), ), obj diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index b19d47c68..04db73e5f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + RssOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.operations.MyOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index 32b9dad3a..1adec809d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + RssOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.aio.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.aio.operations.MyOperations + :ivar rss: RssOperations operations + :vartype rss: _generated.aio.operations.RssOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.rss = RssOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 5cfdf7253..d7d250107 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index a2e0565c5..fb42de5d0 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -56,6 +56,10 @@ build_lollygag_insert_owner_object_request, build_my_pilots_get_pilot_summary_request, build_my_pilots_submit_pilot_request, + build_rss_get_compute_status_request, + build_rss_get_fts_status_request, + build_rss_get_site_status_request, + build_rss_get_storage_status_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2605,3 +2609,303 @@ async def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace_async + async def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index b97d2e439..6684af567 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -12,11 +12,17 @@ from ._models import ( # type: ignore + AllowedStatus, + BannedStatus, BodyAuthGetOidcToken, BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + ComputeElementStatus, + ComputeElementStatusAll, ExtendedMetadata, + FTSStatus, + FTSStatusAll, GroupInfo, HTTPValidationError, HeartbeatData, @@ -34,7 +40,14 @@ SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, + SiteStatus, + SiteStatusAll, SortSpec, + StorageElementStatus, + StorageElementStatusCheck, + StorageElementStatusRead, + StorageElementStatusRemove, + StorageElementStatusWrite, SummaryParams, SummaryParamsSearchItem, SupportInfo, @@ -59,11 +72,17 @@ from diracx.client._generated.models._patch import patch_sdk as _patch_sdk __all__ = [ + "AllowedStatus", + "BannedStatus", "BodyAuthGetOidcToken", "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "ComputeElementStatus", + "ComputeElementStatusAll", "ExtendedMetadata", + "FTSStatus", + "FTSStatusAll", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -81,7 +100,14 @@ "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", + "SiteStatus", + "SiteStatusAll", "SortSpec", + "StorageElementStatus", + "StorageElementStatusCheck", + "StorageElementStatusRead", + "StorageElementStatusRemove", + "StorageElementStatusWrite", "SummaryParams", "SummaryParamsSearchItem", "SupportInfo", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 69b8ffcf1..6953a050f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -16,6 +16,70 @@ JSON = MutableMapping[str, Any] +class AllowedStatus(_serialization.Model): + """AllowedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar warnings: Warnings. + :vartype warnings: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "warnings": {"key": "warnings", "type": "str"}, + } + + def __init__(self, *, allowed: bool, warnings: Optional[str] = None, **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword warnings: Warnings. + :paramtype warnings: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.warnings = warnings + + +class BannedStatus(_serialization.Model): + """BannedStatus. + + All required parameters must be populated in order to send to server. + + :ivar allowed: Allowed. Required. + :vartype allowed: bool + :ivar reason: Reason. + :vartype reason: str + """ + + _validation = { + "allowed": {"required": True}, + } + + _attribute_map = { + "allowed": {"key": "allowed", "type": "bool"}, + "reason": {"key": "reason", "type": "str"}, + } + + def __init__(self, *, allowed: bool, reason: str = "Unknown", **kwargs: Any) -> None: + """ + :keyword allowed: Allowed. Required. + :paramtype allowed: bool + :keyword reason: Reason. + :paramtype reason: str + """ + super().__init__(**kwargs) + self.allowed = allowed + self.reason = reason + + class BodyAuthGetOidcToken(_serialization.Model): """Body_auth_get_oidc_token. @@ -184,6 +248,36 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class ComputeElementStatus(_serialization.Model): + """ComputeElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.ComputeElementStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "ComputeElementStatusAll"}, + } + + def __init__(self, *, all: "_models.ComputeElementStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.ComputeElementStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class ComputeElementStatusAll(_serialization.Model): + """All.""" + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -231,6 +325,36 @@ def __init__( self.gubbins_user_info = gubbins_user_info +class FTSStatus(_serialization.Model): + """FTSStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.FTSStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "FTSStatusAll"}, + } + + def __init__(self, *, all: "_models.FTSStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.FTSStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class FTSStatusAll(_serialization.Model): + """All.""" + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -1282,6 +1406,36 @@ def __init__( self.last_update_time = last_update_time +class SiteStatus(_serialization.Model): + """SiteStatus. + + All required parameters must be populated in order to send to server. + + :ivar all: All. Required. + :vartype all: ~_generated.models.SiteStatusAll + """ + + _validation = { + "all": {"required": True}, + } + + _attribute_map = { + "all": {"key": "all", "type": "SiteStatusAll"}, + } + + def __init__(self, *, all: "_models.SiteStatusAll", **kwargs: Any) -> None: + """ + :keyword all: All. Required. + :paramtype all: ~_generated.models.SiteStatusAll + """ + super().__init__(**kwargs) + self.all = all + + +class SiteStatusAll(_serialization.Model): + """All.""" + + class SortSpec(_serialization.Model): """SortSpec. @@ -1315,6 +1469,77 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class StorageElementStatus(_serialization.Model): + """StorageElementStatus. + + All required parameters must be populated in order to send to server. + + :ivar read: Read. Required. + :vartype read: ~_generated.models.StorageElementStatusRead + :ivar write: Write. Required. + :vartype write: ~_generated.models.StorageElementStatusWrite + :ivar check: Check. Required. + :vartype check: ~_generated.models.StorageElementStatusCheck + :ivar remove: Remove. Required. + :vartype remove: ~_generated.models.StorageElementStatusRemove + """ + + _validation = { + "read": {"required": True}, + "write": {"required": True}, + "check": {"required": True}, + "remove": {"required": True}, + } + + _attribute_map = { + "read": {"key": "read", "type": "StorageElementStatusRead"}, + "write": {"key": "write", "type": "StorageElementStatusWrite"}, + "check": {"key": "check", "type": "StorageElementStatusCheck"}, + "remove": {"key": "remove", "type": "StorageElementStatusRemove"}, + } + + def __init__( + self, + *, + read: "_models.StorageElementStatusRead", + write: "_models.StorageElementStatusWrite", + check: "_models.StorageElementStatusCheck", + remove: "_models.StorageElementStatusRemove", + **kwargs: Any + ) -> None: + """ + :keyword read: Read. Required. + :paramtype read: ~_generated.models.StorageElementStatusRead + :keyword write: Write. Required. + :paramtype write: ~_generated.models.StorageElementStatusWrite + :keyword check: Check. Required. + :paramtype check: ~_generated.models.StorageElementStatusCheck + :keyword remove: Remove. Required. + :paramtype remove: ~_generated.models.StorageElementStatusRemove + """ + super().__init__(**kwargs) + self.read = read + self.write = write + self.check = check + self.remove = remove + + +class StorageElementStatusCheck(_serialization.Model): + """Check.""" + + +class StorageElementStatusRead(_serialization.Model): + """Read.""" + + +class StorageElementStatusRemove(_serialization.Model): + """Remove.""" + + +class StorageElementStatusWrite(_serialization.Model): + """Write.""" + + class SummaryParams(_serialization.Model): """SummaryParams. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 5cfdf7253..d7d250107 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import RssOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "RssOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 7dcaa92ee..7866a140a 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -647,6 +647,118 @@ def build_my_pilots_get_pilot_summary_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_rss_get_storage_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/storage" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_compute_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/compute" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_site_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/site" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + +def build_rss_get_fts_status_request( + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/rss/fts" + + # Construct headers + if if_modified_since is not None: + _headers["if-modified-since"] = _SERIALIZER.header("if_modified_since", if_modified_since, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + if_match = prep_if_match(etag, match_condition) + if if_match is not None: + _headers["If-Match"] = _SERIALIZER.header("if_match", if_match, "str") + if_none_match = prep_if_none_match(etag, match_condition) + if if_none_match is not None: + _headers["If-None-Match"] = _SERIALIZER.header("if_none_match", if_none_match, "str") + + return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -3181,3 +3293,303 @@ def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class RssOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`rss` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @distributed_trace + def get_storage_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.StorageElementStatus]: + """Get Storage Status. + + Get the latest status of storage elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to StorageElementStatus + :rtype: dict[str, ~_generated.models.StorageElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.StorageElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_storage_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{StorageElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_compute_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.ComputeElementStatus]: + """Get Compute Status. + + Get the latest status of compute elements, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to ComputeElementStatus + :rtype: dict[str, ~_generated.models.ComputeElementStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.ComputeElementStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_compute_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{ComputeElementStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_site_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.SiteStatus]: + """Get Site Status. + + Get the latest status of sites, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to SiteStatus + :rtype: dict[str, ~_generated.models.SiteStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.SiteStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_site_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{SiteStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def get_fts_status( + self, + *, + if_modified_since: Optional[str] = None, + etag: Optional[str] = None, + match_condition: Optional[MatchConditions] = None, + **kwargs: Any + ) -> dict[str, _models.FTSStatus]: + """Get Fts Status. + + Get the latest status of FTS servers, scoped to the caller's VO. + + :keyword if_modified_since: Default value is None. + :paramtype if_modified_since: str + :keyword etag: check if resource is changed. Set None to skip checking etag. Default value is + None. + :paramtype etag: str + :keyword match_condition: The match condition to use upon the etag. Default value is None. + :paramtype match_condition: ~azure.core.MatchConditions + :return: dict mapping str to FTSStatus + :rtype: dict[str, ~_generated.models.FTSStatus] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + if match_condition == MatchConditions.IfNotModified: + error_map[412] = ResourceModifiedError + elif match_condition == MatchConditions.IfPresent: + error_map[412] = ResourceNotFoundError + elif match_condition == MatchConditions.IfMissing: + error_map[412] = ResourceExistsError + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[dict[str, _models.FTSStatus]] = kwargs.pop("cls", None) + + _request = build_rss_get_fts_status_request( + if_modified_since=if_modified_since, + etag=etag, + match_condition=match_condition, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("{FTSStatus}", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore