diff --git a/api/v1/proteinconformationprediction_types.go b/api/v1/proteinconformationprediction_types.go index 78b4822..337b8a6 100644 --- a/api/v1/proteinconformationprediction_types.go +++ b/api/v1/proteinconformationprediction_types.go @@ -6,13 +6,20 @@ import ( ) type ProteinConformationPredictionProtein struct { + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + // +kubebuilder:validation:MaxLength=10000 // +kubebuilder:validation:XValidation:rule="self == oldSelf",message="Value is immutable" Sequence string `json:"sequence"` + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinItems=1 // +kubebuilder:validation:XValidation:rule="self == oldSelf",message="Value is immutable" ID []string `json:"id"` } type ProteinConformationPredictionModelWeights struct { + // +kubebuilder:validation:Required + // +kubebuilder:validation:MaxLength=2048 HTTP string `json:"http"` } @@ -31,7 +38,12 @@ type ProteinConformationPredictionModel struct { } type ProteinConformationPredictionDestinationS3 struct { + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=3 + // +kubebuilder:validation:MaxLength=63 Bucket string `json:"bucket"` + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 Region string `json:"region"` } @@ -59,14 +71,19 @@ type ProteinConformationPredictionJob struct { } type ProteinConformationPredictionSpec struct { + // +kubebuilder:validation:Required // +kubebuilder:validation:XValidation:rule="self == oldSelf",message="Value is immutable" - Protein ProteinConformationPredictionProtein `json:"protein"` - Model ProteinConformationPredictionModel `json:"model,omitempty"` + Protein ProteinConformationPredictionProtein `json:"protein"` + // +kubebuilder:validation:Required + Model ProteinConformationPredictionModel `json:"model"` + // +kubebuilder:validation:Required Destination ProteinConformationPredictionDestination `json:"destination"` Notifications ProteinConformationPredictionNotifications `json:"notify,omitempty"` Job ProteinConformationPredictionJob `json:"job,omitempty"` - Database string `json:"database"` - StorageClass string `json:"storageClass,omitempty"` + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + Database string `json:"database"` + StorageClass string `json:"storageClass,omitempty"` } type ProteinConformationPredictionStatusPhase string @@ -84,13 +101,24 @@ type ProteinConformationPredictionStatus struct { Phase ProteinConformationPredictionStatusPhase `json:"phase,omitempty"` SequencePrefix string `json:"sequencePrefix,omitempty"` Error string `json:"error,omitempty"` - RetryCount int32 `json:"retryCount,omitempty"` + // Deprecated: use SearchRetryCount, PredictRetryCount or UploadRetryCount. + RetryCount int32 `json:"retryCount,omitempty"` + SearchRetryCount int32 `json:"searchRetryCount,omitempty"` + PredictRetryCount int32 `json:"predictRetryCount,omitempty"` + UploadRetryCount int32 `json:"uploadRetryCount,omitempty"` + LastTransitionTime *metav1.Time `json:"lastTransitionTime,omitempty"` + // +listType=map + // +listMapKey=type + // +patchMergeKey=type + // +patchStrategy=merge + Conditions []metav1.Condition `json:"conditions,omitempty" patchStrategy:"merge" patchMergeKey:"type"` } // +kubebuilder:object:root=true // +kubebuilder:subresource:status // +kubebuilder:printcolumn:name="Phase",type=string,JSONPath=`.status.phase` // +kubebuilder:printcolumn:name="Sequence",type=string,JSONPath=`.status.sequencePrefix` +// +kubebuilder:printcolumn:name="Ready",type=string,JSONPath=`.status.conditions[?(@.type=='Ready')].status` // +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" type ProteinConformationPrediction struct { diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go index 5cb54ef..cd446e4 100644 --- a/api/v1/zz_generated.deepcopy.go +++ b/api/v1/zz_generated.deepcopy.go @@ -31,7 +31,7 @@ func (in *ProteinConformationPrediction) DeepCopyInto(out *ProteinConformationPr out.TypeMeta = in.TypeMeta in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) in.Spec.DeepCopyInto(&out.Spec) - out.Status = in.Status + in.Status.DeepCopyInto(&out.Status) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProteinConformationPrediction. @@ -257,6 +257,17 @@ func (in *ProteinConformationPredictionSpec) DeepCopy() *ProteinConformationPred // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ProteinConformationPredictionStatus) DeepCopyInto(out *ProteinConformationPredictionStatus) { *out = *in + if in.LastTransitionTime != nil { + in, out := &in.LastTransitionTime, &out.LastTransitionTime + *out = (*in).DeepCopy() + } + if in.Conditions != nil { + in, out := &in.Conditions, &out.Conditions + *out = make([]metav1.Condition, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProteinConformationPredictionStatus. diff --git a/cmd/main.go b/cmd/main.go index 1d1f91a..ed93724 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -39,8 +39,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/webhook" datav1 "github.com/kubefold/operator/api/v1" - "github.com/kubefold/operator/internal/controller" + "github.com/kubefold/operator/internal/database" "github.com/kubefold/operator/internal/observer" + "github.com/kubefold/operator/internal/prediction" // +kubebuilder:scaffold:imports ) @@ -221,17 +222,21 @@ func main() { os.Exit(1) } - if err = (&controller.ProteinDatabaseReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - }).SetupWithManager(mgr); err != nil { + databaseReconciler := database.NewReconciler( + mgr.GetClient(), + mgr.GetScheme(), + mgr.GetEventRecorderFor("proteindatabase-controller"), + ) + if err = databaseReconciler.SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "ProteinDatabase") os.Exit(1) } - if err = (&controller.ProteinConformationPredictionReconciler{ - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - }).SetupWithManager(mgr); err != nil { + predictionReconciler := prediction.NewReconciler( + mgr.GetClient(), + mgr.GetScheme(), + mgr.GetEventRecorderFor("proteinconformationprediction-controller"), + ) + if err = predictionReconciler.SetupWithManager(mgr); err != nil { setupLog.Error(err, "unable to create controller", "controller", "ProteinConformationPrediction") os.Exit(1) } diff --git a/config/crd/bases/data.kubefold.io_proteinconformationpredictions.yaml b/config/crd/bases/data.kubefold.io_proteinconformationpredictions.yaml index e76025f..2ee9a33 100644 --- a/config/crd/bases/data.kubefold.io_proteinconformationpredictions.yaml +++ b/config/crd/bases/data.kubefold.io_proteinconformationpredictions.yaml @@ -21,6 +21,9 @@ spec: - jsonPath: .status.sequencePrefix name: Sequence type: string + - jsonPath: .status.conditions[?(@.type=='Ready')].status + name: Ready + type: string - jsonPath: .metadata.creationTimestamp name: Age type: date @@ -48,14 +51,18 @@ spec: spec: properties: database: + minLength: 1 type: string destination: properties: s3: properties: bucket: + maxLength: 63 + minLength: 3 type: string region: + minLength: 1 type: string required: - bucket @@ -317,6 +324,7 @@ spec: weights: properties: http: + maxLength: 2048 type: string required: - http @@ -340,11 +348,14 @@ spec: id: items: type: string + minItems: 1 type: array x-kubernetes-validations: - message: Value is immutable rule: self == oldSelf sequence: + maxLength: 10000 + minLength: 1 type: string x-kubernetes-validations: - message: Value is immutable @@ -361,19 +372,93 @@ spec: required: - database - destination + - model - protein type: object status: properties: + conditions: + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map error: type: string + lastTransitionTime: + format: date-time + type: string phase: type: string + predictRetryCount: + format: int32 + type: integer retryCount: + description: 'Deprecated: use SearchRetryCount, PredictRetryCount + or UploadRetryCount.' + format: int32 + type: integer + searchRetryCount: format: int32 type: integer sequencePrefix: type: string + uploadRetryCount: + format: int32 + type: integer type: object type: object served: true diff --git a/dist/install.yaml b/dist/install.yaml deleted file mode 100644 index 31b5bd5..0000000 --- a/dist/install.yaml +++ /dev/null @@ -1,1303 +0,0 @@ -apiVersion: v1 -kind: Namespace -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - control-plane: controller-manager - name: kubefold-system ---- -apiVersion: apiextensions.k8s.io/v1 -kind: CustomResourceDefinition -metadata: - annotations: - controller-gen.kubebuilder.io/version: v0.17.2 - name: proteinconformationpredictions.data.kubefold.io -spec: - group: data.kubefold.io - names: - kind: ProteinConformationPrediction - listKind: ProteinConformationPredictionList - plural: proteinconformationpredictions - singular: proteinconformationprediction - scope: Namespaced - versions: - - additionalPrinterColumns: - - jsonPath: .status.phase - name: Phase - type: string - - jsonPath: .status.sequencePrefix - name: Sequence - type: string - - jsonPath: .metadata.creationTimestamp - name: Age - type: date - name: v1 - schema: - openAPIV3Schema: - properties: - apiVersion: - description: |- - APIVersion defines the versioned schema of this representation of an object. - Servers should convert recognized schemas to the latest internal value, and - may reject unrecognized values. - More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources - type: string - kind: - description: |- - Kind is a string value representing the REST resource this object represents. - Servers may infer this from the endpoint the client submits requests to. - Cannot be updated. - In CamelCase. - More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds - type: string - metadata: - type: object - spec: - properties: - database: - type: string - destination: - properties: - s3: - properties: - bucket: - type: string - region: - type: string - required: - - bucket - - region - type: object - required: - - s3 - type: object - job: - properties: - predictionNodeSelector: - description: |- - A node selector represents the union of the results of one or more label queries - over a set of nodes; that is, it represents the OR of the selectors represented - by the node selector terms. - properties: - nodeSelectorTerms: - description: Required. A list of node selector terms. The - terms are ORed. - items: - description: |- - A null or empty node selector term matches no objects. The requirements of - them are ANDed. - The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. - properties: - matchExpressions: - description: A list of node selector requirements by - node's labels. - items: - description: |- - A node selector requirement is a selector that contains values, a key, and an operator - that relates the key and values. - properties: - key: - description: The label key that the selector applies - to. - type: string - operator: - description: |- - Represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. - type: string - values: - description: |- - An array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. If the operator is Gt or Lt, the values - array must have a single element, which will be interpreted as an integer. - This array is replaced during a strategic merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - matchFields: - description: A list of node selector requirements by - node's fields. - items: - description: |- - A node selector requirement is a selector that contains values, a key, and an operator - that relates the key and values. - properties: - key: - description: The label key that the selector applies - to. - type: string - operator: - description: |- - Represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. - type: string - values: - description: |- - An array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. If the operator is Gt or Lt, the values - array must have a single element, which will be interpreted as an integer. - This array is replaced during a strategic merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - type: object - x-kubernetes-map-type: atomic - type: array - x-kubernetes-list-type: atomic - required: - - nodeSelectorTerms - type: object - x-kubernetes-map-type: atomic - profile: - type: string - searchNodeSelector: - description: |- - A node selector represents the union of the results of one or more label queries - over a set of nodes; that is, it represents the OR of the selectors represented - by the node selector terms. - properties: - nodeSelectorTerms: - description: Required. A list of node selector terms. The - terms are ORed. - items: - description: |- - A null or empty node selector term matches no objects. The requirements of - them are ANDed. - The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. - properties: - matchExpressions: - description: A list of node selector requirements by - node's labels. - items: - description: |- - A node selector requirement is a selector that contains values, a key, and an operator - that relates the key and values. - properties: - key: - description: The label key that the selector applies - to. - type: string - operator: - description: |- - Represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. - type: string - values: - description: |- - An array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. If the operator is Gt or Lt, the values - array must have a single element, which will be interpreted as an integer. - This array is replaced during a strategic merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - matchFields: - description: A list of node selector requirements by - node's fields. - items: - description: |- - A node selector requirement is a selector that contains values, a key, and an operator - that relates the key and values. - properties: - key: - description: The label key that the selector applies - to. - type: string - operator: - description: |- - Represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. - type: string - values: - description: |- - An array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. If the operator is Gt or Lt, the values - array must have a single element, which will be interpreted as an integer. - This array is replaced during a strategic merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - type: object - x-kubernetes-map-type: atomic - type: array - x-kubernetes-list-type: atomic - required: - - nodeSelectorTerms - type: object - x-kubernetes-map-type: atomic - type: object - model: - properties: - seeds: - items: - type: integer - type: array - x-kubernetes-validations: - - message: Value is immutable - rule: self == oldSelf - volume: - properties: - selector: - description: |- - A label selector is a label query over a set of resources. The result of matchLabels and - matchExpressions are ANDed. An empty label selector matches all objects. A null - label selector matches no objects. - properties: - matchExpressions: - description: matchExpressions is a list of label selector - requirements. The requirements are ANDed. - items: - description: |- - A label selector requirement is a selector that contains values, a key, and an operator that - relates the key and values. - properties: - key: - description: key is the label key that the selector - applies to. - type: string - operator: - description: |- - operator represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists and DoesNotExist. - type: string - values: - description: |- - values is an array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. This array is replaced during a strategic - merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - matchLabels: - additionalProperties: - type: string - description: |- - matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels - map is equivalent to an element of matchExpressions, whose key field is "key", the - operator is "In", and the values array contains only "value". The requirements are ANDed. - type: object - type: object - x-kubernetes-map-type: atomic - storageClassName: - type: string - type: object - weights: - properties: - http: - type: string - required: - - http - type: object - required: - - weights - type: object - notify: - properties: - region: - type: string - sms: - items: - type: string - type: array - required: - - region - type: object - protein: - properties: - id: - items: - type: string - type: array - x-kubernetes-validations: - - message: Value is immutable - rule: self == oldSelf - sequence: - type: string - x-kubernetes-validations: - - message: Value is immutable - rule: self == oldSelf - required: - - id - - sequence - type: object - x-kubernetes-validations: - - message: Value is immutable - rule: self == oldSelf - storageClass: - type: string - required: - - database - - destination - - protein - type: object - status: - properties: - error: - type: string - phase: - type: string - retryCount: - format: int32 - type: integer - sequencePrefix: - type: string - type: object - type: object - served: true - storage: true - subresources: - status: {} ---- -apiVersion: apiextensions.k8s.io/v1 -kind: CustomResourceDefinition -metadata: - annotations: - controller-gen.kubebuilder.io/version: v0.17.2 - name: proteindatabases.data.kubefold.io -spec: - group: data.kubefold.io - names: - kind: ProteinDatabase - listKind: ProteinDatabaseList - plural: proteindatabases - singular: proteindatabase - scope: Namespaced - versions: - - additionalPrinterColumns: - - jsonPath: .status.downloadStatus - name: Status - type: string - - jsonPath: .status.progress - name: Progress - type: string - - jsonPath: .status.size - name: Size - type: string - - jsonPath: .status.totalSize - name: Total Size - type: string - - jsonPath: .status.downloadSpeed - name: Download Speed - type: string - - jsonPath: .status.volumeName - name: Volume - type: string - - jsonPath: .metadata.creationTimestamp - name: Age - type: date - name: v1 - schema: - openAPIV3Schema: - properties: - apiVersion: - description: |- - APIVersion defines the versioned schema of this representation of an object. - Servers should convert recognized schemas to the latest internal value, and - may reject unrecognized values. - More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources - type: string - kind: - description: |- - Kind is a string value representing the REST resource this object represents. - Servers may infer this from the endpoint the client submits requests to. - Cannot be updated. - In CamelCase. - More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds - type: string - metadata: - type: object - spec: - properties: - datasets: - properties: - bfd: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - mgyclusters: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - nt: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - pdb: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - pdbseqreq: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - rfam: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - rnacentral: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - uniprot: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - uniref90: - type: boolean - x-kubernetes-validations: - - message: Dataset can only be enabled. Deletion of dataset is - not supported yet - rule: self == oldSelf || (self == true && oldSelf == false) - required: - - bfd - - mgyclusters - - nt - - pdb - - pdbseqreq - - rfam - - rnacentral - - uniprot - - uniref90 - type: object - volume: - properties: - annotations: - additionalProperties: - type: string - type: object - labels: - additionalProperties: - type: string - type: object - selector: - description: |- - A label selector is a label query over a set of resources. The result of matchLabels and - matchExpressions are ANDed. An empty label selector matches all objects. A null - label selector matches no objects. - properties: - matchExpressions: - description: matchExpressions is a list of label selector - requirements. The requirements are ANDed. - items: - description: |- - A label selector requirement is a selector that contains values, a key, and an operator that - relates the key and values. - properties: - key: - description: key is the label key that the selector - applies to. - type: string - operator: - description: |- - operator represents a key's relationship to a set of values. - Valid operators are In, NotIn, Exists and DoesNotExist. - type: string - values: - description: |- - values is an array of string values. If the operator is In or NotIn, - the values array must be non-empty. If the operator is Exists or DoesNotExist, - the values array must be empty. This array is replaced during a strategic - merge patch. - items: - type: string - type: array - x-kubernetes-list-type: atomic - required: - - key - - operator - type: object - type: array - x-kubernetes-list-type: atomic - matchLabels: - additionalProperties: - type: string - description: |- - matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels - map is equivalent to an element of matchExpressions, whose key field is "key", the - operator is "In", and the values array contains only "value". The requirements are ANDed. - type: object - type: object - x-kubernetes-map-type: atomic - storageClassName: - type: string - type: object - type: object - status: - properties: - datasets: - properties: - bfd: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - mgyclusters: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - nt: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - pdb: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - pdbseqreq: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - rfam: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - rnacentral: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - uniprot: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - uniref90: - properties: - delta: - format: int64 - type: integer - deltaDuration: - type: string - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - format: int64 - type: integer - totalSize: - format: int64 - type: integer - type: object - type: object - downloadSpeed: - type: string - downloadStatus: - type: string - lastUpdate: - format: date-time - type: string - progress: - type: string - size: - type: string - totalSize: - type: string - volumeName: - type: string - required: - - volumeName - type: object - type: object - served: true - storage: true - subresources: - status: {} ---- -apiVersion: v1 -kind: ServiceAccount -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-controller-manager - namespace: kubefold-system ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: Role -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-leader-election-role - namespace: kubefold-system -rules: -- apiGroups: - - "" - resources: - - configmaps - verbs: - - get - - list - - watch - - create - - update - - patch - - delete -- apiGroups: - - coordination.k8s.io - resources: - - leases - verbs: - - get - - list - - watch - - create - - update - - patch - - delete -- apiGroups: - - "" - resources: - - events - verbs: - - create - - patch ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: kubefold-manager-role -rules: -- apiGroups: - - "" - resources: - - persistentvolumeclaims - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - "" - resources: - - pods - verbs: - - get - - list - - watch -- apiGroups: - - "" - resources: - - pods/log - verbs: - - get -- apiGroups: - - batch - resources: - - jobs - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions - - proteindatabases - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions/finalizers - - proteindatabases/finalizers - verbs: - - update -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions/status - - proteindatabases/status - verbs: - - get - - patch - - update ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: kubefold-metrics-auth-role -rules: -- apiGroups: - - authentication.k8s.io - resources: - - tokenreviews - verbs: - - create -- apiGroups: - - authorization.k8s.io - resources: - - subjectaccessreviews - verbs: - - create ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: kubefold-metrics-reader -rules: -- nonResourceURLs: - - /metrics - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteinconformationprediction-admin-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions - verbs: - - '*' -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions/status - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteinconformationprediction-editor-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions/status - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteinconformationprediction-viewer-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions - verbs: - - get - - list - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteinconformationpredictions/status - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteindatabase-admin-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases - verbs: - - '*' -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases/status - verbs: - - '*' -- apiGroups: - - core - resources: - - persistentvolumeclaims - verbs: - - '*' -- apiGroups: - - batch - resources: - - jobs - verbs: - - '*' -- apiGroups: - - core - resources: - - pods - verbs: - - '*' -- apiGroups: - - core - resources: - - pods/log - verbs: - - '*' ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteindatabase-editor-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases/status - verbs: - - get -- apiGroups: - - core - resources: - - persistentvolumeclaims - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - batch - resources: - - jobs - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - core - resources: - - pods - verbs: - - get - - list - - watch -- apiGroups: - - core - resources: - - pods/log - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-proteindatabase-viewer-role -rules: -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases - verbs: - - get - - list - - watch -- apiGroups: - - data.kubefold.io - resources: - - proteindatabases/status - verbs: - - get ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-leader-election-rolebinding - namespace: kubefold-system -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: Role - name: kubefold-leader-election-role -subjects: -- kind: ServiceAccount - name: kubefold-controller-manager - namespace: kubefold-system ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - name: kubefold-manager-rolebinding -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: kubefold-manager-role -subjects: -- kind: ServiceAccount - name: kubefold-controller-manager - namespace: kubefold-system ---- -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: kubefold-metrics-auth-rolebinding -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: kubefold-metrics-auth-role -subjects: -- kind: ServiceAccount - name: kubefold-controller-manager - namespace: kubefold-system ---- -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - control-plane: controller-manager - name: kubefold-controller-manager-metrics-service - namespace: kubefold-system -spec: - ports: - - name: https - port: 8443 - protocol: TCP - targetPort: 8443 - selector: - app.kubernetes.io/name: kubefold - control-plane: controller-manager ---- -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app.kubernetes.io/managed-by: kustomize - app.kubernetes.io/name: kubefold - control-plane: controller-manager - name: kubefold-controller-manager - namespace: kubefold-system -spec: - replicas: 1 - selector: - matchLabels: - app.kubernetes.io/name: kubefold - control-plane: controller-manager - template: - metadata: - annotations: - kubectl.kubernetes.io/default-container: manager - labels: - app.kubernetes.io/name: kubefold - control-plane: controller-manager - spec: - containers: - - args: - - --metrics-bind-address=:8443 - - --leader-elect - - --health-probe-bind-address=:8081 - command: - - /manager - image: ghcr.io/kubefold/operator:latest - imagePullPolicy: Always - livenessProbe: - httpGet: - path: /healthz - port: 8081 - initialDelaySeconds: 15 - periodSeconds: 20 - name: manager - ports: [] - readinessProbe: - httpGet: - path: /readyz - port: 8081 - initialDelaySeconds: 5 - periodSeconds: 10 - resources: - limits: - cpu: 500m - memory: 128Mi - requests: - cpu: 10m - memory: 64Mi - securityContext: - allowPrivilegeEscalation: false - capabilities: - drop: - - ALL - volumeMounts: [] - securityContext: - runAsNonRoot: true - seccompProfile: - type: RuntimeDefault - serviceAccountName: kubefold-controller-manager - terminationGracePeriodSeconds: 10 - volumes: [] diff --git a/down.sh b/down.sh index fec8e2a..01012b7 100755 --- a/down.sh +++ b/down.sh @@ -1,5 +1,6 @@ #!/bin/bash +export AWS_PROFILE=solidchat export AWS_PAGER="" aws fsx describe-file-systems --region eu-central-1 --query "FileSystems[?FileSystemType=='LUSTRE'].FileSystemId" --output text | xargs -n1 -I {} aws fsx delete-file-system --region eu-central-1 --file-system-id {} eksctl delete cluster -f eks/cluster.yaml --disable-nodegroup-eviction \ No newline at end of file diff --git a/eks/cluster.yaml b/eks/cluster.yaml index 575baf1..a83415a 100644 --- a/eks/cluster.yaml +++ b/eks/cluster.yaml @@ -7,7 +7,9 @@ metadata: managedNodeGroups: - name: ng-primary instanceType: m5.xlarge - desiredCapacity: 1 + desiredCapacity: 3 + minSize: 1 + maxSize: 6 ssh: allow: true subnets: @@ -27,7 +29,9 @@ managedNodeGroups: - arn:aws:iam::aws:policy/AmazonSSMManagedInstanceCore - name: ng-gpu instanceType: g5.xlarge - desiredCapacity: 1 + desiredCapacity: 3 + minSize: 1 + maxSize: 6 ssh: allow: true subnets: diff --git a/internal/controller/constants.go b/internal/controller/constants.go deleted file mode 100644 index 5baff8b..0000000 --- a/internal/controller/constants.go +++ /dev/null @@ -1,23 +0,0 @@ -package controller - -import ( - "time" - - corev1 "k8s.io/api/core/v1" -) - -const ( - ProteinDatabaseFinalizer = "data.kubefold.io/finalizer" - PersistentVolumeClaimNameSuffix = "-data" - PersistentVolumeClaimSize = "1Gi" - ReconcileInterval = 10 * time.Second - - DownloaderImage = "ghcr.io/kubefold/downloader" - DownloaderImagePullPolicy = corev1.PullAlways - - ManagerImage = "ghcr.io/kubefold/manager" - ManagerImagePullPolicy = corev1.PullAlways - - AlphafoldImage = "public.ecr.aws/k3x1v3b7/alphafold3:latest" - AlphafoldImagePullPolicy = corev1.PullAlways -) diff --git a/internal/controller/proteinconformationprediction_controller.go b/internal/controller/proteinconformationprediction_controller.go deleted file mode 100644 index 9b1a5d8..0000000 --- a/internal/controller/proteinconformationprediction_controller.go +++ /dev/null @@ -1,1138 +0,0 @@ -package controller - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "time" - - "github.com/kubefold/operator/internal/alphafold" - - batchv1 "k8s.io/api/batch/v1" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" - "k8s.io/client-go/tools/record" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - logf "sigs.k8s.io/controller-runtime/pkg/log" - - datav1 "github.com/kubefold/operator/api/v1" -) - -const ( - ProteinConformationPredictionFinalizer = "proteinconformationprediction.data.kubefold.io/finalizer" - DefaultStorageClass = "fsx-sc" - DefaultJobTimeout = 24 * time.Hour - MaxRetries = 3 -) - -type ProteinConformationPredictionReconciler struct { - client.Client - Scheme *runtime.Scheme - Recorder record.EventRecorder -} - -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions,verbs=get;list;watch;create;update;patch;delete -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions/status,verbs=get;update;patch -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions/finalizers,verbs=update -// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete -// +kubebuilder:rbac:groups="",resources=persistentvolumeclaims,verbs=get;list;watch;create;update;patch;delete - -//nolint:gocyclo -func (r *ProteinConformationPredictionReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - pred := &datav1.ProteinConformationPrediction{} - err := r.Get(ctx, req.NamespacedName, pred) - if err != nil { - if errors.IsNotFound(err) { - return ctrl.Result{}, nil - } - log.Error(err, "Failed to get ProteinConformationPrediction") - return ctrl.Result{}, err - } - - if !pred.DeletionTimestamp.IsZero() { - if controllerutil.ContainsFinalizer(pred, ProteinConformationPredictionFinalizer) { - if err := r.cleanupResources(ctx, pred); err != nil { - log.Error(err, "Failed to clean up resources") - return ctrl.Result{}, err - } - - controllerutil.RemoveFinalizer(pred, ProteinConformationPredictionFinalizer) - if err := r.Update(ctx, pred); err != nil { - log.Error(err, "Failed to remove finalizer") - return ctrl.Result{}, err - } - } - return ctrl.Result{}, nil - } - - if !controllerutil.ContainsFinalizer(pred, ProteinConformationPredictionFinalizer) { - controllerutil.AddFinalizer(pred, ProteinConformationPredictionFinalizer) - if err := r.Update(ctx, pred); err != nil { - log.Error(err, "Failed to add finalizer") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil - } - - if err := r.validateSpec(pred); err != nil { - log.Error(err, "Invalid spec") - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - pred.Status.Error = err.Error() - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - - if pred.Status.Phase != datav1.ProteinConformationPredictionStatusPhaseFailed && - pred.Status.Phase != datav1.ProteinConformationPredictionStatusPhaseCompleted { - - searchJobName := fmt.Sprintf("%s-search", pred.Name) - searchJob := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: searchJobName, Namespace: pred.Namespace}, searchJob) - if err == nil { - if searchJob.Status.Failed > 0 { - if pred.Status.RetryCount < MaxRetries { - pred.Status.RetryCount++ - log.Info("Search job failed, retrying", "Job", searchJobName, "RetryCount", pred.Status.RetryCount) - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update retry count") - return ctrl.Result{}, err - } - if err := r.Delete(ctx, searchJob); err != nil { - log.Error(err, "Failed to delete failed search job") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil - } - log.Info("Search job failed after max retries", "Job", searchJobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - pred.Status.Error = "Search job failed after max retries" - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - } - - if pred.Status.Phase == datav1.ProteinConformationPredictionStatusPhasePredicting { - predJobName := fmt.Sprintf("%s-predict", pred.Name) - predJob := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: predJobName, Namespace: pred.Namespace}, predJob) - if err == nil { - if predJob.Status.Failed > 0 { - if pred.Status.RetryCount < MaxRetries { - pred.Status.RetryCount++ - log.Info("Prediction job failed, retrying", "Job", predJobName, "RetryCount", pred.Status.RetryCount) - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update retry count") - return ctrl.Result{}, err - } - if err := r.Delete(ctx, predJob); err != nil { - log.Error(err, "Failed to delete failed prediction job") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil - } - log.Info("Prediction job failed after max retries", "Job", predJobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - pred.Status.Error = "Prediction job failed after max retries" - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - } - } - - if pred.Status.Phase == datav1.ProteinConformationPredictionStatusPhaseUploadingArtifacts { - uploadJobName := fmt.Sprintf("%s-upload", pred.Name) - uploadJob := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: uploadJobName, Namespace: pred.Namespace}, uploadJob) - if err == nil { - if uploadJob.Status.Failed > 0 { - if pred.Status.RetryCount < MaxRetries { - pred.Status.RetryCount++ - log.Info("Upload job failed, retrying", "Job", uploadJobName, "RetryCount", pred.Status.RetryCount) - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update retry count") - return ctrl.Result{}, err - } - if err := r.Delete(ctx, uploadJob); err != nil { - log.Error(err, "Failed to delete failed upload job") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil - } - log.Info("Upload job failed after max retries", "Job", uploadJobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - pred.Status.Error = "Upload job failed after max retries" - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - } - } - } - - if pred.Status.Phase == "" { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseNotStarted - pred.Status.SequencePrefix = pred.Spec.Protein.Sequence[:10] + "..." - pred.Status.RetryCount = 0 - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil - } - - switch pred.Status.Phase { - case datav1.ProteinConformationPredictionStatusPhaseNotStarted: - return r.handleNotStarted(ctx, pred) - case datav1.ProteinConformationPredictionStatusPhaseAligning: - return r.handleAligning(ctx, pred) - case datav1.ProteinConformationPredictionStatusPhasePredicting: - return r.handlePredicting(ctx, pred) - case datav1.ProteinConformationPredictionStatusPhaseUploadingArtifacts: - return r.handleUploadingArtifacts(ctx, pred) - case datav1.ProteinConformationPredictionStatusPhaseCompleted, datav1.ProteinConformationPredictionStatusPhaseFailed: - return ctrl.Result{}, nil - default: - log.Info("Unknown phase", "Phase", pred.Status.Phase) - return ctrl.Result{}, nil - } -} - -func (r *ProteinConformationPredictionReconciler) handleNotStarted(ctx context.Context, pred *datav1.ProteinConformationPrediction) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - proteinDB := &datav1.ProteinDatabase{} - err := r.Get(ctx, types.NamespacedName{Name: pred.Spec.Database, Namespace: pred.Namespace}, proteinDB) - if err != nil { - if errors.IsNotFound(err) { - log.Info("Waiting for ProteinDatabase to be created", "Database", pred.Spec.Database) - r.Recorder.Event(pred, corev1.EventTypeNormal, "DatabaseNotFound", fmt.Sprintf("Waiting for ProteinDatabase %s to be created", pred.Spec.Database)) - return ctrl.Result{RequeueAfter: time.Second * 10}, nil - } - log.Error(err, "Failed to get ProteinDatabase") - r.Recorder.Event(pred, corev1.EventTypeWarning, "DatabaseError", fmt.Sprintf("Failed to get ProteinDatabase: %v", err)) - return ctrl.Result{}, err - } - - pvcName := fmt.Sprintf("%s-data", pred.Name) - pvc := &corev1.PersistentVolumeClaim{} - err = r.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: pred.Namespace}, pvc) - if err != nil { - if errors.IsNotFound(err) { - pvc = r.newPVC(pred, pvcName) - if err := controllerutil.SetControllerReference(pred, pvc, r.Scheme); err != nil { - log.Error(err, "Failed to set controller reference for PVC") - r.Recorder.Event(pred, corev1.EventTypeWarning, "PVCReferenceError", fmt.Sprintf("Failed to set controller reference for PVC: %v", err)) - return ctrl.Result{}, err - } - if err := r.Create(ctx, pvc); err != nil { - log.Error(err, "Failed to create PVC") - r.Recorder.Event(pred, corev1.EventTypeWarning, "PVCCreationError", fmt.Sprintf("Failed to create PVC: %v", err)) - return ctrl.Result{}, err - } - log.Info("Created PVC", "Name", pvcName) - r.Recorder.Event(pred, corev1.EventTypeNormal, "PVCCreated", fmt.Sprintf("Created PVC %s", pvcName)) - return ctrl.Result{Requeue: true}, nil - } - log.Error(err, "Failed to get PVC") - r.Recorder.Event(pred, corev1.EventTypeWarning, "PVCError", fmt.Sprintf("Failed to get PVC: %v", err)) - return ctrl.Result{}, err - } - - jobName := fmt.Sprintf("%s-search", pred.Name) - job := &batchv1.Job{} - err = r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err != nil { - if errors.IsNotFound(err) { - encodedInput, err := r.prepareFoldInput(pred, false) - if err != nil { - log.Error(err, "Failed to prepare FoldInput") - r.Recorder.Event(pred, corev1.EventTypeWarning, "InputError", fmt.Sprintf("Failed to prepare FoldInput: %v", err)) - return ctrl.Result{}, err - } - - job = r.newSearchJob(pred, jobName, pvcName, encodedInput) - if err := controllerutil.SetControllerReference(pred, job, r.Scheme); err != nil { - log.Error(err, "Failed to set controller reference for search job") - r.Recorder.Event(pred, corev1.EventTypeWarning, "JobReferenceError", fmt.Sprintf("Failed to set controller reference for search job: %v", err)) - return ctrl.Result{}, err - } - if err := r.Create(ctx, job); err != nil { - log.Error(err, "Failed to create search job") - r.Recorder.Event(pred, corev1.EventTypeWarning, "JobCreationError", fmt.Sprintf("Failed to create search job: %v", err)) - return ctrl.Result{}, err - } - - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseAligning - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - r.Recorder.Event(pred, corev1.EventTypeWarning, "StatusUpdateError", fmt.Sprintf("Failed to update status: %v", err)) - return ctrl.Result{}, err - } - - log.Info("Created search job and updated status", "Name", jobName) - r.Recorder.Event(pred, corev1.EventTypeNormal, "JobCreated", fmt.Sprintf("Created search job %s", jobName)) - return ctrl.Result{Requeue: true}, nil - } - log.Error(err, "Failed to get search job") - r.Recorder.Event(pred, corev1.EventTypeWarning, "JobError", fmt.Sprintf("Failed to get search job: %v", err)) - return ctrl.Result{}, err - } - - return ctrl.Result{}, nil -} - -func (r *ProteinConformationPredictionReconciler) handleAligning(ctx context.Context, pred *datav1.ProteinConformationPrediction) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - jobName := fmt.Sprintf("%s-search", pred.Name) - job := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err != nil { - log.Error(err, "Failed to get search job") - return ctrl.Result{}, err - } - - if r.checkJobTimeout(job) { - log.Info("Search job timed out", "Job", jobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - - if job.Status.Succeeded > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhasePredicting - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - log.Info("Search job completed, moving to prediction phase") - return ctrl.Result{Requeue: true}, nil - } - - if job.Status.Failed > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - log.Info("Search job failed") - return ctrl.Result{}, nil - } - - return ctrl.Result{RequeueAfter: time.Second * 10}, nil -} - -func (r *ProteinConformationPredictionReconciler) handlePredicting(ctx context.Context, pred *datav1.ProteinConformationPrediction) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - jobName := fmt.Sprintf("%s-predict", pred.Name) - job := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err != nil { - if errors.IsNotFound(err) { - encodedInput, err := r.prepareFoldInput(pred, true) - if err != nil { - log.Error(err, "Failed to prepare FoldInput") - return ctrl.Result{}, err - } - - pvcName := fmt.Sprintf("%s-data", pred.Name) - job = r.newPredictionJob(pred, jobName, pvcName, encodedInput) - if err := controllerutil.SetControllerReference(pred, job, r.Scheme); err != nil { - log.Error(err, "Failed to set controller reference for prediction job") - return ctrl.Result{}, err - } - if err := r.Create(ctx, job); err != nil { - log.Error(err, "Failed to create prediction job") - return ctrl.Result{}, err - } - log.Info("Created prediction job", "Name", jobName) - return ctrl.Result{Requeue: true}, nil - } - log.Error(err, "Failed to get prediction job") - return ctrl.Result{}, err - } - - if r.checkJobTimeout(job) { - log.Info("Prediction job timed out", "Job", jobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - - if job.Status.Succeeded > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseUploadingArtifacts - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - log.Info("Prediction job completed, moving to uploading artifacts phase") - return ctrl.Result{Requeue: true}, nil - } - - if job.Status.Failed > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - log.Info("Prediction job failed") - return ctrl.Result{}, nil - } - - return ctrl.Result{RequeueAfter: time.Second * 10}, nil -} - -func (r *ProteinConformationPredictionReconciler) handleUploadingArtifacts(ctx context.Context, pred *datav1.ProteinConformationPrediction) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - jobName := fmt.Sprintf("%s-upload", pred.Name) - job := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err != nil { - if errors.IsNotFound(err) { - pvcName := fmt.Sprintf("%s-data", pred.Name) - job = r.newUploadArtifactsJob(pred, jobName, pvcName) - if err := controllerutil.SetControllerReference(pred, job, r.Scheme); err != nil { - log.Error(err, "Failed to set controller reference for upload artifacts job") - return ctrl.Result{}, err - } - if err := r.Create(ctx, job); err != nil { - log.Error(err, "Failed to create upload artifacts job") - return ctrl.Result{}, err - } - log.Info("Created upload artifacts job", "Name", jobName) - return ctrl.Result{Requeue: true}, nil - } - log.Error(err, "Failed to get upload artifacts job") - return ctrl.Result{}, err - } - - if r.checkJobTimeout(job) { - log.Info("Upload job timed out", "Job", jobName) - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - return ctrl.Result{}, nil - } - - if job.Status.Succeeded > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseCompleted - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - - if err := r.cleanupCompletedJobs(ctx, pred); err != nil { - log.Error(err, "Failed to cleanup completed jobs") - return ctrl.Result{}, err - } - - pvcName := fmt.Sprintf("%s-data", pred.Name) - pvc := &corev1.PersistentVolumeClaim{} - err := r.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: pred.Namespace}, pvc) - if err == nil { - if err := r.Delete(ctx, pvc); err != nil { - log.Error(err, "Failed to delete PVC", "Name", pvcName) - return ctrl.Result{}, err - } - log.Info("Deleted PVC", "Name", pvcName) - } - - log.Info("Upload artifacts job completed, resource is now in completed state") - return ctrl.Result{}, nil - } - - if job.Status.Failed > 0 { - pred.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed - if err := r.Status().Update(ctx, pred); err != nil { - log.Error(err, "Failed to update ProteinConformationPrediction status") - return ctrl.Result{}, err - } - log.Info("Upload artifacts job failed") - return ctrl.Result{}, nil - } - - return ctrl.Result{RequeueAfter: time.Second * 10}, nil -} - -func (r *ProteinConformationPredictionReconciler) newPVC(pred *datav1.ProteinConformationPrediction, pvcName string) *corev1.PersistentVolumeClaim { - storageClass := DefaultStorageClass - if pred.Spec.Model.Volume.StorageClassName != nil && *pred.Spec.Model.Volume.StorageClassName != "" { - storageClass = *pred.Spec.Model.Volume.StorageClassName - } else if pred.Spec.StorageClass != "" { - storageClass = pred.Spec.StorageClass - } - - pvc := &corev1.PersistentVolumeClaim{ - ObjectMeta: metav1.ObjectMeta{ - Name: pvcName, - Namespace: pred.Namespace, - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "app.kubernetes.io/name": "proteinconformationprediction-data", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: corev1.PersistentVolumeClaimSpec{ - AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce}, - Resources: corev1.VolumeResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceStorage: resource.MustParse("10Gi"), - }, - }, - StorageClassName: &storageClass, - }, - } - - if pred.Spec.Model.Volume.Selector != nil { - pvc.Spec.Selector = pred.Spec.Model.Volume.Selector - } - - return pvc -} - -func (r *ProteinConformationPredictionReconciler) prepareFoldInput(pred *datav1.ProteinConformationPrediction, prediction bool) (string, error) { - input := alphafold.Input{ - Name: fmt.Sprintf("%s-%s", pred.Namespace, pred.Name), - Sequences: []alphafold.Sequence{ - { - Protein: alphafold.Protein{ - Sequence: pred.Spec.Protein.Sequence, - ID: pred.Spec.Protein.ID, - }, - }, - }, - ModelSeeds: pred.Spec.Model.Seeds, - Dialect: "alphafold3", - Version: 1, - } - if prediction { - empty := "" - emptyList := make([]string, 0) - input.Sequences[0].Protein.Templates = &emptyList - input.Sequences[0].Protein.UnpairedMSA = &empty - input.Sequences[0].Protein.PairedMSA = &empty - } - - inputJson, err := json.Marshal(input) - if err != nil { - return "", fmt.Errorf("failed to marshal fold input: %w", err) - } - - return base64.StdEncoding.EncodeToString(inputJson), nil -} - -type phaseResources struct { - cpu string - memory string -} - -var profileResourceMap = map[datav1.ProteinConformationPredictionProfile]map[string]phaseResources{ - datav1.ProteinConformationPredictionProfileSmall: { - "search": {cpu: "1500m", memory: "6Gi"}, - "predict": {cpu: "1", memory: "4Gi"}, - "upload": {cpu: "100m", memory: "256Mi"}, - }, - datav1.ProteinConformationPredictionProfileMedium: { - "search": {cpu: "3", memory: "12Gi"}, - "predict": {cpu: "1", memory: "8Gi"}, - "upload": {cpu: "100m", memory: "256Mi"}, - }, - datav1.ProteinConformationPredictionProfileLarge: { - "search": {cpu: "6", memory: "48Gi"}, - "predict": {cpu: "2", memory: "16Gi"}, - "upload": {cpu: "100m", memory: "256Mi"}, - }, -} - -func resourcesForPhase(profile datav1.ProteinConformationPredictionProfile, phase string) corev1.ResourceRequirements { - p := profile - if p == "" { - p = datav1.ProteinConformationPredictionProfileMedium - } - phases, ok := profileResourceMap[p] - if !ok { - phases = profileResourceMap[datav1.ProteinConformationPredictionProfileMedium] - } - pr := phases[phase] - return corev1.ResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse(pr.cpu), - corev1.ResourceMemory: resource.MustParse(pr.memory), - }, - } -} - -func (r *ProteinConformationPredictionReconciler) newSearchJob(pred *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job { - backoffLimit := int32(2) - - job := &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: jobName, - Namespace: pred.Namespace, - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "search", - "app.kubernetes.io/name": "proteinconformationprediction-search", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: batchv1.JobSpec{ - BackoffLimit: &backoffLimit, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "search", - "app.kubernetes.io/name": "proteinconformationprediction-search", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: corev1.PodSpec{ - RestartPolicy: corev1.RestartPolicyNever, - InitContainers: []corev1.Container{ - { - Name: "input-placement", - Image: ManagerImage, - ImagePullPolicy: ManagerImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Env: []corev1.EnvVar{ - { - Name: "INPUT_PATH", - Value: "/data/af_input", - }, - { - Name: "OUTPUT_PATH", - Value: "/data/af_output", - }, - { - Name: "ENCODED_INPUT", - Value: encodedInput, - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - { - Name: "database", - MountPath: "/public_databases", - }, - }, - }, - }, - Containers: []corev1.Container{ - { - Name: "search", - Image: AlphafoldImage, - ImagePullPolicy: AlphafoldImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Resources: resourcesForPhase(pred.Spec.Job.Profile, "search"), - Command: []string{"uv"}, - Args: []string{ - "run", - "python3", - "run_alphafold.py", - "--json_path=/data/af_input/fold_input.json", - "--output_dir=/data/af_output", - "--model_dir=/data/models", - "--db_dir=/public_databases", - "--run_inference=false", - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - { - Name: "database", - MountPath: "/public_databases", - }, - }, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "data", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: pvcName, - }, - }, - }, - { - Name: "database", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: fmt.Sprintf("%s-data", pred.Spec.Database), - }, - }, - }, - }, - }, - }, - }, - } - - if pred.Spec.Job.SearchNodeSelector.NodeSelectorTerms != nil { - job.Spec.Template.Spec.NodeSelector = map[string]string{} - for _, term := range pred.Spec.Job.SearchNodeSelector.NodeSelectorTerms { - for _, exp := range term.MatchExpressions { - if exp.Operator == corev1.NodeSelectorOpIn && len(exp.Values) > 0 { - job.Spec.Template.Spec.NodeSelector[exp.Key] = exp.Values[0] - } - } - } - } - - return job -} - -func (r *ProteinConformationPredictionReconciler) newPredictionJob(pred *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job { - backoffLimit := int32(2) - - job := &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: jobName, - Namespace: pred.Namespace, - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "predict", - "app.kubernetes.io/name": "proteinconformationprediction-predict", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: batchv1.JobSpec{ - BackoffLimit: &backoffLimit, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "predict", - "app.kubernetes.io/name": "proteinconformationprediction-predict", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: corev1.PodSpec{ - RestartPolicy: corev1.RestartPolicyNever, - InitContainers: []corev1.Container{ - { - Name: "input-placement", - Image: ManagerImage, - ImagePullPolicy: ManagerImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Env: []corev1.EnvVar{ - { - Name: "INPUT_PATH", - Value: "/data/af_input", - }, - { - Name: "OUTPUT_PATH", - Value: "/data/af_output", - }, - { - Name: "ENCODED_INPUT", - Value: encodedInput, - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - { - Name: "database", - MountPath: "/public_databases", - }, - }, - }, - { - Name: "weights-placement", - Image: ManagerImage, - ImagePullPolicy: ManagerImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Command: []string{ - "sh", - }, - Args: []string{ - "-c", - fmt.Sprintf("mkdir -p /data/models; wget --tries=3 --timeout=30 -O /data/models/af3.bin.zst %s && unzstd /data/models/af3.bin.zst || (echo 'Failed to download or extract weights' && exit 1)", pred.Spec.Model.Weights.HTTP), - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - { - Name: "database", - MountPath: "/public_databases", - }, - }, - }, - }, - Containers: []corev1.Container{ - { - Name: "predict", - Image: AlphafoldImage, - ImagePullPolicy: AlphafoldImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Resources: func() corev1.ResourceRequirements { - r := resourcesForPhase(pred.Spec.Job.Profile, "predict") - r.Requests["nvidia.com/gpu"] = resource.MustParse("1") - if r.Limits == nil { - r.Limits = corev1.ResourceList{} - } - r.Limits["nvidia.com/gpu"] = resource.MustParse("1") - return r - }(), - Command: []string{"uv"}, - Args: []string{ - "run", - "python3", - "run_alphafold.py", - "--json_path=/data/af_input/fold_input.json", - "--output_dir=/data/af_output", - "--model_dir=/data/models", - "--db_dir=/public_databases", - "--run_data_pipeline=false", - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - { - Name: "database", - MountPath: "/public_databases", - }, - }, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "data", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: pvcName, - }, - }, - }, - { - Name: "database", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: fmt.Sprintf("%s-data", pred.Spec.Database), - }, - }, - }, - }, - }, - }, - }, - } - - if pred.Spec.Job.PredictionNodeSelector.NodeSelectorTerms != nil { - job.Spec.Template.Spec.NodeSelector = map[string]string{} - for _, term := range pred.Spec.Job.PredictionNodeSelector.NodeSelectorTerms { - for _, exp := range term.MatchExpressions { - if exp.Operator == corev1.NodeSelectorOpIn && len(exp.Values) > 0 { - job.Spec.Template.Spec.NodeSelector[exp.Key] = exp.Values[0] - } - } - } - } - - return job -} - -func (r *ProteinConformationPredictionReconciler) newUploadArtifactsJob(pred *datav1.ProteinConformationPrediction, jobName, pvcName string) *batchv1.Job { - backoffLimit := int32(2) - - var phoneNumbers string - if len(pred.Spec.Notifications.SMS) > 0 { - phoneNumbers = pred.Spec.Notifications.SMS[0] - for i := 1; i < len(pred.Spec.Notifications.SMS); i++ { - phoneNumbers += "," + pred.Spec.Notifications.SMS[i] - } - } - - job := &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: jobName, - Namespace: pred.Namespace, - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "upload", - "app.kubernetes.io/name": "proteinconformationprediction-upload", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: batchv1.JobSpec{ - BackoffLimit: &backoffLimit, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - "app": pred.Name, - "data.kubefold.io/prediction": pred.Name, - "data.kubefold.io/step": "upload", - "app.kubernetes.io/name": "proteinconformationprediction-upload", - "app.kubernetes.io/instance": pred.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - }, - }, - Spec: corev1.PodSpec{ - RestartPolicy: corev1.RestartPolicyNever, - Containers: []corev1.Container{ - { - Name: "upload", - Image: ManagerImage, - ImagePullPolicy: ManagerImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Resources: resourcesForPhase(pred.Spec.Job.Profile, "upload"), - Env: []corev1.EnvVar{ - { - Name: "INPUT_PATH", - Value: "/data/af_input", - }, - { - Name: "OUTPUT_PATH", - Value: "/data/af_output", - }, - { - Name: "BUCKET", - Value: pred.Spec.Destination.S3.Bucket, - }, - { - Name: "AWS_REGION", - Value: pred.Spec.Destination.S3.Region, - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "data", - MountPath: "/data", - }, - }, - }, - { - Name: "notify", - Image: ManagerImage, - ImagePullPolicy: ManagerImagePullPolicy, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &[]bool{false}[0], - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{"ALL"}, - }, - }, - Resources: resourcesForPhase(pred.Spec.Job.Profile, "upload"), - Env: []corev1.EnvVar{ - { - Name: "INPUT_PATH", - Value: "/data/af_input", - }, - { - Name: "OUTPUT_PATH", - Value: "/data/af_output", - }, - { - Name: "NOTIFICATION_PHONES", - Value: phoneNumbers, - }, - { - Name: "NOTIFICATION_MESSAGE", - Value: fmt.Sprintf("Protein Conformation Prediction %s in namespace %s completed. Artifacts has been uploaded to %s", pred.Name, pred.Namespace, pred.Spec.Destination.S3.Bucket), - }, - { - Name: "AWS_REGION", - Value: pred.Spec.Destination.S3.Region, - }, - }, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "data", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: pvcName, - }, - }, - }, - }, - }, - }, - }, - } - - return job -} - -func (r *ProteinConformationPredictionReconciler) cleanupResources(ctx context.Context, pred *datav1.ProteinConformationPrediction) error { - log := logf.FromContext(ctx) - - pvcName := fmt.Sprintf("%s-data", pred.Name) - pvc := &corev1.PersistentVolumeClaim{} - err := r.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: pred.Namespace}, pvc) - if err == nil { - if err := r.Delete(ctx, pvc); err != nil { - log.Error(err, "Failed to delete PVC", "Name", pvcName) - return err - } - } - - jobNames := []string{ - fmt.Sprintf("%s-search", pred.Name), - fmt.Sprintf("%s-predict", pred.Name), - fmt.Sprintf("%s-upload", pred.Name), - } - - for _, jobName := range jobNames { - job := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err == nil { - if err := r.Delete(ctx, job); err != nil { - log.Error(err, "Failed to delete job", "Name", jobName) - return err - } - } - } - - return nil -} - -func (r *ProteinConformationPredictionReconciler) validateSpec(pred *datav1.ProteinConformationPrediction) error { - if pred.Spec.Protein.Sequence == "" { - return fmt.Errorf("protein sequence cannot be empty") - } - if pred.Spec.Database == "" { - return fmt.Errorf("database reference cannot be empty") - } - if pred.Spec.Destination.S3.Bucket == "" { - return fmt.Errorf("destination S3 bucket cannot be empty") - } - if pred.Spec.Destination.S3.Region == "" { - return fmt.Errorf("destination S3 region cannot be empty") - } - if pred.Spec.Model.Weights.HTTP == "" { - return fmt.Errorf("model weights HTTP URL cannot be empty") - } - return nil -} - -func (r *ProteinConformationPredictionReconciler) cleanupCompletedJobs(ctx context.Context, pred *datav1.ProteinConformationPrediction) error { - log := logf.FromContext(ctx) - - jobNames := []string{ - fmt.Sprintf("%s-search", pred.Name), - fmt.Sprintf("%s-predict", pred.Name), - fmt.Sprintf("%s-upload", pred.Name), - } - - for _, jobName := range jobNames { - job := &batchv1.Job{} - err := r.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pred.Namespace}, job) - if err == nil && job.Status.Succeeded > 0 { - if err := r.Delete(ctx, job); err != nil { - log.Error(err, "Failed to delete completed job", "Name", jobName) - return err - } - } - } - - return nil -} - -func (r *ProteinConformationPredictionReconciler) checkJobTimeout(job *batchv1.Job) bool { - if job.Status.StartTime == nil { - return false - } - - timeout := DefaultJobTimeout - if job.Spec.ActiveDeadlineSeconds != nil { - timeout = time.Duration(*job.Spec.ActiveDeadlineSeconds) * time.Second - } - - return time.Since(job.Status.StartTime.Time) > timeout -} - -func (r *ProteinConformationPredictionReconciler) SetupWithManager(mgr ctrl.Manager) error { - r.Recorder = mgr.GetEventRecorderFor("proteinconformationprediction-controller") - return ctrl.NewControllerManagedBy(mgr). - For(&datav1.ProteinConformationPrediction{}). - Owns(&corev1.PersistentVolumeClaim{}). - Owns(&batchv1.Job{}). - Named("proteinconformationprediction"). - Complete(r) -} diff --git a/internal/controller/proteinconformationprediction_controller_test.go b/internal/controller/proteinconformationprediction_controller_test.go deleted file mode 100644 index 93d66a6..0000000 --- a/internal/controller/proteinconformationprediction_controller_test.go +++ /dev/null @@ -1,130 +0,0 @@ -/* -Copyright 2025 Mateusz Woźniak . - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package controller - -import ( - "context" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/reconcile" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - datav1 "github.com/kubefold/operator/api/v1" - v1 "k8s.io/api/core/v1" -) - -var _ = Describe("ProteinConformationPrediction Controller", func() { - Context("When reconciling a resource", func() { - const resourceName = "test-resource" - - ctx := context.Background() - - typeNamespacedName := types.NamespacedName{ - Name: resourceName, - Namespace: "default", // TODO(user):Modify as needed - } - proteinconformationprediction := &datav1.ProteinConformationPrediction{} - - BeforeEach(func() { - By("creating the custom resource for the Kind ProteinConformationPrediction") - err := k8sClient.Get(ctx, typeNamespacedName, proteinconformationprediction) - if err != nil && errors.IsNotFound(err) { - resource := &datav1.ProteinConformationPrediction{ - ObjectMeta: metav1.ObjectMeta{ - Name: resourceName, - Namespace: "default", - }, - Spec: datav1.ProteinConformationPredictionSpec{ - Protein: datav1.ProteinConformationPredictionProtein{ - ID: []string{"test-id"}, - Sequence: "TESTSEQUENCE", - }, - Database: "test-database", - Destination: datav1.ProteinConformationPredictionDestination{ - S3: datav1.ProteinConformationPredictionDestinationS3{ - Bucket: "test-bucket", - Region: "test-region", - }, - }, - Model: datav1.ProteinConformationPredictionModel{ - Weights: datav1.ProteinConformationPredictionModelWeights{ - HTTP: "http://test-weights", - }, - }, - Job: datav1.ProteinConformationPredictionJob{ - SearchNodeSelector: v1.NodeSelector{ - NodeSelectorTerms: []v1.NodeSelectorTerm{ - { - MatchExpressions: []v1.NodeSelectorRequirement{ - { - Key: "test-key", - Operator: v1.NodeSelectorOpIn, - Values: []string{"test-value"}, - }, - }, - }, - }, - }, - PredictionNodeSelector: v1.NodeSelector{ - NodeSelectorTerms: []v1.NodeSelectorTerm{ - { - MatchExpressions: []v1.NodeSelectorRequirement{ - { - Key: "test-key", - Operator: v1.NodeSelectorOpIn, - Values: []string{"test-value"}, - }, - }, - }, - }, - }, - }, - }, - } - Expect(k8sClient.Create(ctx, resource)).To(Succeed()) - } - }) - - AfterEach(func() { - // TODO(user): Cleanup logic after each test, like removing the resource instance. - resource := &datav1.ProteinConformationPrediction{} - err := k8sClient.Get(ctx, typeNamespacedName, resource) - Expect(err).NotTo(HaveOccurred()) - - By("Cleanup the specific resource instance ProteinConformationPrediction") - Expect(k8sClient.Delete(ctx, resource)).To(Succeed()) - }) - It("should successfully reconcile the resource", func() { - By("Reconciling the created resource") - controllerReconciler := &ProteinConformationPredictionReconciler{ - Client: k8sClient, - Scheme: k8sClient.Scheme(), - } - - _, err := controllerReconciler.Reconcile(ctx, reconcile.Request{ - NamespacedName: typeNamespacedName, - }) - Expect(err).NotTo(HaveOccurred()) - // TODO(user): Add more specific assertions depending on your controller's reconciliation logic. - // Example: If you expect a certain status condition after reconciliation, verify it here. - }) - }) -}) diff --git a/internal/controller/proteindatabase_controller.go b/internal/controller/proteindatabase_controller.go deleted file mode 100644 index fe0d440..0000000 --- a/internal/controller/proteindatabase_controller.go +++ /dev/null @@ -1,102 +0,0 @@ -package controller - -import ( - "context" - - batchv1 "k8s.io/api/batch/v1" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/runtime" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - logf "sigs.k8s.io/controller-runtime/pkg/log" - - datav1 "github.com/kubefold/operator/api/v1" -) - -type ProteinDatabaseReconciler struct { - client.Client - Scheme *runtime.Scheme - volumeHandler *VolumeHandler - jobHandler *JobHandler - finalizerHandler *FinalizerHandler -} - -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases,verbs=get;list;watch;create;update;patch;delete -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases/status,verbs=get;update;patch -// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases/finalizers,verbs=update -// +kubebuilder:rbac:groups=core,resources=persistentvolumeclaims,verbs=get;list;watch;create;update;patch;delete -// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete -// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch -// +kubebuilder:rbac:groups="",resources=pods/log,verbs=get - -//nolint:gocyclo -func (r *ProteinDatabaseReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - pd := &datav1.ProteinDatabase{} - if err := r.Get(ctx, req.NamespacedName, pd); err != nil { - if errors.IsNotFound(err) { - return ctrl.Result{}, nil - } - log.Error(err, "Failed to get ProteinDatabase") - return ctrl.Result{}, err - } - - if r.volumeHandler == nil { - r.volumeHandler = &VolumeHandler{client: r.Client, scheme: r.Scheme} - } - if r.jobHandler == nil { - r.jobHandler = &JobHandler{client: r.Client, scheme: r.Scheme} - } - if r.finalizerHandler == nil { - r.finalizerHandler = &FinalizerHandler{client: r.Client, scheme: r.Scheme} - } - - if !pd.DeletionTimestamp.IsZero() { - return r.finalizerHandler.handleDeletion(ctx, pd) - } - - if !controllerutil.ContainsFinalizer(pd, ProteinDatabaseFinalizer) { - return r.finalizerHandler.ensureFinalizer(ctx, pd) - } - - pvc, result, err := r.volumeHandler.ensurePVC(ctx, pd) - if err != nil { - return ctrl.Result{}, err - } - if result != nil { - return *result, nil - } - - if pd.Status.VolumeName != pvc.Spec.VolumeName { - if err := r.updateStatus(ctx, pd, pvc); err != nil { - log.Error(err, "Failed to update ProteinDatabase status") - return ctrl.Result{}, err - } - } - - if err := r.jobHandler.ensureJobs(ctx, pd, pvc); err != nil { - log.Error(err, "Failed to ensure downloader jobs") - return ctrl.Result{}, err - } - - log.Info("Reconciliation completed successfully") - return ctrl.Result{}, nil -} - -func (r *ProteinDatabaseReconciler) SetupWithManager(mgr ctrl.Manager) error { - return ctrl.NewControllerManagedBy(mgr). - For(&datav1.ProteinDatabase{}). - Owns(&corev1.PersistentVolumeClaim{}). - Owns(&batchv1.Job{}). - Named("proteindatabase"). - Complete(r) -} - -func (r *ProteinDatabaseReconciler) updateStatus(ctx context.Context, pd *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim) error { - pdCopy := pd.DeepCopy() - pdCopy.Status.VolumeName = pvc.Spec.VolumeName - return r.Status().Update(ctx, pdCopy) -} diff --git a/internal/controller/proteindatabase_controller_test.go b/internal/controller/proteindatabase_controller_test.go deleted file mode 100644 index f49d961..0000000 --- a/internal/controller/proteindatabase_controller_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2025 Mateusz Woźniak . - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package controller - -import ( - "context" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/reconcile" - - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - datav1 "github.com/kubefold/operator/api/v1" -) - -var _ = Describe("ProteinDatabase Controller", func() { - Context("When reconciling a resource", func() { - const resourceName = "test-resource" - - ctx := context.Background() - - typeNamespacedName := types.NamespacedName{ - Name: resourceName, - Namespace: "default", // TODO(user):Modify as needed - } - proteindatabase := &datav1.ProteinDatabase{} - - BeforeEach(func() { - By("creating the custom resource for the Kind ProteinDatabase") - err := k8sClient.Get(ctx, typeNamespacedName, proteindatabase) - if err != nil && errors.IsNotFound(err) { - resource := &datav1.ProteinDatabase{ - ObjectMeta: metav1.ObjectMeta{ - Name: resourceName, - Namespace: "default", - }, - // TODO(user): Specify other spec details if needed. - } - Expect(k8sClient.Create(ctx, resource)).To(Succeed()) - } - }) - - AfterEach(func() { - // TODO(user): Cleanup logic after each test, like removing the resource instance. - resource := &datav1.ProteinDatabase{} - err := k8sClient.Get(ctx, typeNamespacedName, resource) - Expect(err).NotTo(HaveOccurred()) - - By("Cleanup the specific resource instance ProteinDatabase") - Expect(k8sClient.Delete(ctx, resource)).To(Succeed()) - }) - It("should successfully reconcile the resource", func() { - By("Reconciling the created resource") - controllerReconciler := &ProteinDatabaseReconciler{ - Client: k8sClient, - Scheme: k8sClient.Scheme(), - } - - _, err := controllerReconciler.Reconcile(ctx, reconcile.Request{ - NamespacedName: typeNamespacedName, - }) - Expect(err).NotTo(HaveOccurred()) - // TODO(user): Add more specific assertions depending on your controller's reconciliation logic. - // Example: If you expect a certain status condition after reconciliation, verify it here. - }) - }) -}) diff --git a/internal/controller/proteindatabase_finalizer_handler.go b/internal/controller/proteindatabase_finalizer_handler.go deleted file mode 100644 index 7c4a387..0000000 --- a/internal/controller/proteindatabase_finalizer_handler.go +++ /dev/null @@ -1,51 +0,0 @@ -package controller - -import ( - "context" - - "k8s.io/apimachinery/pkg/runtime" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - logf "sigs.k8s.io/controller-runtime/pkg/log" - - datav1 "github.com/kubefold/operator/api/v1" -) - -type FinalizerHandler struct { - client client.Client - scheme *runtime.Scheme -} - -func (f *FinalizerHandler) handleDeletion(ctx context.Context, pd *datav1.ProteinDatabase) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - if controllerutil.ContainsFinalizer(pd, ProteinDatabaseFinalizer) { - if err := f.cleanupResources(ctx, pd); err != nil { - log.Error(err, "Failed to clean up resources") - return ctrl.Result{}, err - } - - controllerutil.RemoveFinalizer(pd, ProteinDatabaseFinalizer) - if err := f.client.Update(ctx, pd); err != nil { - log.Error(err, "Failed to remove finalizer") - return ctrl.Result{}, err - } - } - return ctrl.Result{}, nil -} - -func (f *FinalizerHandler) ensureFinalizer(ctx context.Context, pd *datav1.ProteinDatabase) (ctrl.Result, error) { - log := logf.FromContext(ctx) - - controllerutil.AddFinalizer(pd, ProteinDatabaseFinalizer) - if err := f.client.Update(ctx, pd); err != nil { - log.Error(err, "Failed to add finalizer") - return ctrl.Result{}, err - } - return ctrl.Result{Requeue: true}, nil -} - -func (f *FinalizerHandler) cleanupResources(ctx context.Context, pd *datav1.ProteinDatabase) error { - return nil -} diff --git a/internal/controller/proteindatabase_job_handler.go b/internal/controller/proteindatabase_job_handler.go deleted file mode 100644 index 06952e1..0000000 --- a/internal/controller/proteindatabase_job_handler.go +++ /dev/null @@ -1,128 +0,0 @@ -package controller - -import ( - "context" - "fmt" - "strings" - - batchv1 "k8s.io/api/batch/v1" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - logf "sigs.k8s.io/controller-runtime/pkg/log" - - downloaderTypes "github.com/kubefold/downloader/pkg/types" - datav1 "github.com/kubefold/operator/api/v1" -) - -type JobHandler struct { - client client.Client - scheme *runtime.Scheme -} - -func (j *JobHandler) ensureJobs(ctx context.Context, pd *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim) error { - log := logf.FromContext(ctx) - - for _, dataset := range getDatasets(pd) { - if err := j.ensureJob(ctx, pd, pvc, dataset); err != nil { - log.Error(err, "Failed to ensure job for dataset", "dataset", dataset) - return err - } - } - - return nil -} - -func (j *JobHandler) ensureJob(ctx context.Context, pd *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim, dataset downloaderTypes.Dataset) error { - log := logf.FromContext(ctx) - jobName := strings.ReplaceAll(strings.ToLower(fmt.Sprintf("%s-%s-downloader", pd.Name, dataset.ShortName())), "_", "-") - - existingJob := &batchv1.Job{} - err := j.client.Get(ctx, types.NamespacedName{Name: jobName, Namespace: pd.Namespace}, existingJob) - if err == nil { - log.Info("Job already exists", "jobName", jobName) - return nil - } - - if !errors.IsNotFound(err) { - return err - } - - job := j.createJobSpec(pd, pvc, dataset, jobName) - - if err := controllerutil.SetControllerReference(pd, job, j.scheme); err != nil { - return err - } - - log.Info("Creating downloader job", "jobName", jobName, "dataset", dataset) - if err := j.client.Create(ctx, job); err != nil { - return err - } - - return nil -} - -func (j *JobHandler) createJobSpec(pd *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim, dataset downloaderTypes.Dataset, jobName string) *batchv1.Job { - labels := map[string]string{ - "data.kubefold.io/dataset": dataset.ShortName(), - "data.kubefold.io/database": pd.Name, - "app.kubernetes.io/name": "proteindatabase-downloader", - "app.kubernetes.io/instance": pd.Name, - "app.kubernetes.io/managed-by": "kubefold-operator", - } - - return &batchv1.Job{ - ObjectMeta: metav1.ObjectMeta{ - Name: jobName, - Namespace: pd.Namespace, - Labels: labels, - }, - Spec: batchv1.JobSpec{ - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: labels, - }, - Spec: corev1.PodSpec{ - RestartPolicy: corev1.RestartPolicyOnFailure, - Containers: []corev1.Container{ - { - Name: "downloader", - Image: DownloaderImage, - ImagePullPolicy: DownloaderImagePullPolicy, - Env: []corev1.EnvVar{ - { - Name: "DATASET", - Value: dataset.String(), - }, - { - Name: "DESTINATION", - Value: "/public_databases", - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "databases", - MountPath: "/public_databases", - }, - }, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "databases", - VolumeSource: corev1.VolumeSource{ - PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ - ClaimName: pvc.Name, - }, - }, - }, - }, - }, - }, - }, - } -} diff --git a/internal/controller/proteindatabase_utils.go b/internal/controller/proteindatabase_utils.go deleted file mode 100644 index 1daa879..0000000 --- a/internal/controller/proteindatabase_utils.go +++ /dev/null @@ -1,38 +0,0 @@ -package controller - -import ( - "github.com/kubefold/downloader/pkg/types" - datav1 "github.com/kubefold/operator/api/v1" -) - -func getDatasets(pd *datav1.ProteinDatabase) []types.Dataset { - o := make([]types.Dataset, 0) - if pd.Spec.Datasets.MGYClusters { - o = append(o, types.DatasetMGYClusters) - } - if pd.Spec.Datasets.BFD { - o = append(o, types.DatasetBFD) - } - if pd.Spec.Datasets.UniRef90 { - o = append(o, types.DatasetUniRef90) - } - if pd.Spec.Datasets.UniProt { - o = append(o, types.DatasetUniProt) - } - if pd.Spec.Datasets.PDB { - o = append(o, types.DatasetPDB) - } - if pd.Spec.Datasets.PDBSeqReq { - o = append(o, types.DatasetPDBSeqReq) - } - if pd.Spec.Datasets.RNACentral { - o = append(o, types.DatasetRNACentral) - } - if pd.Spec.Datasets.NT { - o = append(o, types.DatasetNT) - } - if pd.Spec.Datasets.RFam { - o = append(o, types.DatasetRFam) - } - return o -} diff --git a/internal/controller/proteindatabase_volume_handler.go b/internal/controller/proteindatabase_volume_handler.go deleted file mode 100644 index 2a17003..0000000 --- a/internal/controller/proteindatabase_volume_handler.go +++ /dev/null @@ -1,128 +0,0 @@ -package controller - -import ( - "context" - "fmt" - - downloaderTypes "github.com/kubefold/downloader/pkg/types" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/api/resource" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/types" - ctrl "sigs.k8s.io/controller-runtime" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" - logf "sigs.k8s.io/controller-runtime/pkg/log" - - datav1 "github.com/kubefold/operator/api/v1" -) - -type VolumeHandler struct { - client client.Client - scheme *runtime.Scheme -} - -func (v *VolumeHandler) ensurePVC(ctx context.Context, pd *datav1.ProteinDatabase) (*corev1.PersistentVolumeClaim, *ctrl.Result, error) { - log := logf.FromContext(ctx) - pvcName := pd.Name + PersistentVolumeClaimNameSuffix - - pvc := &corev1.PersistentVolumeClaim{} - err := v.client.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: pd.Namespace}, pvc) - - if err != nil && errors.IsNotFound(err) { - pvc, err = v.createPVC(ctx, pd) - if err != nil { - log.Error(err, "Failed to create PVC") - return nil, nil, err - } - log.Info("Created new PVC", "pvcName", pvc.Name) - } else if err != nil { - log.Error(err, "Failed to get PVC") - return nil, nil, err - } - - if pvc.Status.Phase != corev1.ClaimBound { - log.Info("PVC is not bound yet", "pvcName", pvc.Name, "phase", pvc.Status.Phase) - result := ctrl.Result{Requeue: true, RequeueAfter: ReconcileInterval} - return pvc, &result, nil - } - - return pvc, nil, nil -} - -func (v *VolumeHandler) createPVC(ctx context.Context, pd *datav1.ProteinDatabase) (*corev1.PersistentVolumeClaim, error) { - pvcName := pd.Name + PersistentVolumeClaimNameSuffix - - labels := pd.Spec.Volume.Labels - if labels == nil { - labels = make(map[string]string) - } - labels["data.kubefold.io/database"] = pd.Name - labels["app.kubernetes.io/name"] = "proteindatabase" - labels["app.kubernetes.io/instance"] = pd.Name - labels["app.kubernetes.io/managed-by"] = "kubefold-operator" - - var requestedSize int64 - if pd.Spec.Datasets.MGYClusters { - requestedSize += downloaderTypes.DatasetMGYClusters.Size() - } - if pd.Spec.Datasets.BFD { - requestedSize += downloaderTypes.DatasetBFD.Size() - } - if pd.Spec.Datasets.UniRef90 { - requestedSize += downloaderTypes.DatasetUniRef90.Size() - } - if pd.Spec.Datasets.UniProt { - requestedSize += downloaderTypes.DatasetUniProt.Size() - } - if pd.Spec.Datasets.PDB { - requestedSize += downloaderTypes.DatasetPDB.Size() - } - if pd.Spec.Datasets.PDBSeqReq { - requestedSize += downloaderTypes.DatasetPDBSeqReq.Size() - } - if pd.Spec.Datasets.RNACentral { - requestedSize += downloaderTypes.DatasetRNACentral.Size() - } - if pd.Spec.Datasets.NT { - requestedSize += downloaderTypes.DatasetNT.Size() - } - if pd.Spec.Datasets.RFam { - requestedSize += downloaderTypes.DatasetRFam.Size() - } - requestedSizeGigabytes := requestedSize / 1024 / 1024 / 1024 - requestedSizeGigabytes += 2 - - pvc := &corev1.PersistentVolumeClaim{ - ObjectMeta: metav1.ObjectMeta{ - Name: pvcName, - Namespace: pd.Namespace, - Labels: labels, - }, - Spec: corev1.PersistentVolumeClaimSpec{ - AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteMany}, - StorageClassName: pd.Spec.Volume.StorageClassName, - Resources: corev1.VolumeResourceRequirements{ - Requests: corev1.ResourceList{ - corev1.ResourceStorage: resource.MustParse(fmt.Sprintf("%dGi", requestedSizeGigabytes)), - }, - }, - }, - } - - if pd.Spec.Volume.Selector != nil { - pvc.Spec.Selector = pd.Spec.Volume.Selector - } - - if err := controllerutil.SetControllerReference(pd, pvc, v.scheme); err != nil { - return nil, err - } - - if err := v.client.Create(ctx, pvc); err != nil { - return nil, err - } - - return pvc, nil -} diff --git a/internal/controller/suite_test.go b/internal/controller/suite_test.go deleted file mode 100644 index 8980c75..0000000 --- a/internal/controller/suite_test.go +++ /dev/null @@ -1,116 +0,0 @@ -/* -Copyright 2025 Mateusz Woźniak . - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package controller - -import ( - "context" - "os" - "path/filepath" - "testing" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - - "k8s.io/client-go/kubernetes/scheme" - "k8s.io/client-go/rest" - "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/controller-runtime/pkg/envtest" - logf "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/controller-runtime/pkg/log/zap" - - datav1 "github.com/kubefold/operator/api/v1" - // +kubebuilder:scaffold:imports -) - -// These tests use Ginkgo (BDD-style Go testing framework). Refer to -// http://onsi.github.io/ginkgo/ to learn more about Ginkgo. - -var ( - ctx context.Context - cancel context.CancelFunc - testEnv *envtest.Environment - cfg *rest.Config - k8sClient client.Client -) - -func TestControllers(t *testing.T) { - RegisterFailHandler(Fail) - - RunSpecs(t, "Controller Suite") -} - -var _ = BeforeSuite(func() { - logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) - - ctx, cancel = context.WithCancel(context.TODO()) - - var err error - err = datav1.AddToScheme(scheme.Scheme) - Expect(err).NotTo(HaveOccurred()) - - // +kubebuilder:scaffold:scheme - - By("bootstrapping test environment") - testEnv = &envtest.Environment{ - CRDDirectoryPaths: []string{filepath.Join("..", "..", "config", "crd", "bases")}, - ErrorIfCRDPathMissing: true, - } - - // Retrieve the first found binary directory to allow running tests from IDEs - if getFirstFoundEnvTestBinaryDir() != "" { - testEnv.BinaryAssetsDirectory = getFirstFoundEnvTestBinaryDir() - } - - // cfg is defined in this file globally. - cfg, err = testEnv.Start() - Expect(err).NotTo(HaveOccurred()) - Expect(cfg).NotTo(BeNil()) - - k8sClient, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}) - Expect(err).NotTo(HaveOccurred()) - Expect(k8sClient).NotTo(BeNil()) -}) - -var _ = AfterSuite(func() { - By("tearing down the test environment") - cancel() - err := testEnv.Stop() - Expect(err).NotTo(HaveOccurred()) -}) - -// getFirstFoundEnvTestBinaryDir locates the first binary in the specified path. -// ENVTEST-based tests depend on specific binaries, usually located in paths set by -// controller-runtime. When running tests directly (e.g., via an IDE) without using -// Makefile targets, the 'BinaryAssetsDirectory' must be explicitly configured. -// -// This function streamlines the process by finding the required binaries, similar to -// setting the 'KUBEBUILDER_ASSETS' environment variable. To ensure the binaries are -// properly set up, run 'make setup-envtest' beforehand. -func getFirstFoundEnvTestBinaryDir() string { - basePath := filepath.Join("..", "..", "bin", "k8s") - entries, err := os.ReadDir(basePath) - if err != nil { - logf.Log.Error(err, "Failed to read directory", "path", basePath) - return "" - } - for _, entry := range entries { - if entry.IsDir() { - return filepath.Join(basePath, entry.Name()) - } - } - return "" -} diff --git a/internal/database/controller.go b/internal/database/controller.go new file mode 100644 index 0000000..b6de9d7 --- /dev/null +++ b/internal/database/controller.go @@ -0,0 +1,115 @@ +package database + +import ( + "context" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +type Reconciler struct { + client.Client + Scheme *runtime.Scheme + Recorder record.EventRecorder + volume VolumeReconciler + jobs JobReconciler + finalizer FinalizerReconciler +} + +func NewReconciler(c client.Client, scheme *runtime.Scheme, recorder record.EventRecorder) *Reconciler { + enumerator := NewDatasetEnumerator() + sizer := NewSizer(enumerator) + return &Reconciler{ + Client: c, + Scheme: scheme, + Recorder: recorder, + volume: NewVolumeReconciler(c, scheme, sizer), + jobs: NewJobReconciler(c, scheme, enumerator), + finalizer: NewFinalizerReconciler(c), + } +} + +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases/status,verbs=get;update;patch +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteindatabases/finalizers,verbs=update +// +kubebuilder:rbac:groups=core,resources=persistentvolumeclaims,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch +// +kubebuilder:rbac:groups="",resources=pods/log,verbs=get + +func (r *Reconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := logf.FromContext(ctx) + + database := &datav1.ProteinDatabase{} + if err := r.Get(ctx, req.NamespacedName, database); err != nil { + if errors.IsNotFound(err) { + return ctrl.Result{}, nil + } + log.Error(err, "Failed to get ProteinDatabase") + return ctrl.Result{}, err + } + + if !database.DeletionTimestamp.IsZero() { + return r.finalizer.HandleDeletion(ctx, database) + } + + if result, err := r.finalizer.Ensure(ctx, database); err != nil || !result.IsZero() { + return result, err + } + + pvc, result, err := r.volume.Ensure(ctx, database) + if err != nil { + return ctrl.Result{}, err + } + if result != nil { + return *result, nil + } + + if database.Status.VolumeName != pvc.Spec.VolumeName { + if err := r.updateVolumeName(ctx, database, pvc); err != nil { + log.Error(err, "Failed to update ProteinDatabase status") + return ctrl.Result{}, err + } + } + + if err := r.jobs.EnsureAll(ctx, database, pvc); err != nil { + log.Error(err, "Failed to ensure downloader jobs") + return ctrl.Result{}, err + } + + log.Info("Reconciliation completed successfully") + return ctrl.Result{}, nil +} + +func (r *Reconciler) updateVolumeName(ctx context.Context, database *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim) error { + return shared.RetryOnConflict(ctx, func() error { + latest := &datav1.ProteinDatabase{} + if err := r.Get(ctx, types.NamespacedName{Name: database.Name, Namespace: database.Namespace}, latest); err != nil { + return err + } + latest.Status.VolumeName = pvc.Spec.VolumeName + return r.Status().Update(ctx, latest) + }) +} + +func (r *Reconciler) SetupWithManager(mgr ctrl.Manager) error { + if r.Recorder == nil { + r.Recorder = mgr.GetEventRecorderFor("proteindatabase-controller") + } + return ctrl.NewControllerManagedBy(mgr). + For(&datav1.ProteinDatabase{}). + Owns(&corev1.PersistentVolumeClaim{}). + Owns(&batchv1.Job{}). + Named("proteindatabase"). + Complete(r) +} diff --git a/internal/database/datasets.go b/internal/database/datasets.go new file mode 100644 index 0000000..cdbe869 --- /dev/null +++ b/internal/database/datasets.go @@ -0,0 +1,53 @@ +package database + +import ( + downloaderTypes "github.com/kubefold/downloader/pkg/types" + + datav1 "github.com/kubefold/operator/api/v1" +) + +type DatasetEnumerator interface { + FromSpec(database *datav1.ProteinDatabase) []downloaderTypes.Dataset + All() []downloaderTypes.Dataset +} + +type datasetEnumerator struct{} + +func NewDatasetEnumerator() DatasetEnumerator { + return &datasetEnumerator{} +} + +type datasetEntry struct { + dataset downloaderTypes.Dataset + enabled func(*datav1.ProteinDatabase) bool +} + +var datasetEntries = []datasetEntry{ + {downloaderTypes.DatasetMGYClusters, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.MGYClusters }}, + {downloaderTypes.DatasetBFD, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.BFD }}, + {downloaderTypes.DatasetUniRef90, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.UniRef90 }}, + {downloaderTypes.DatasetUniProt, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.UniProt }}, + {downloaderTypes.DatasetPDB, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.PDB }}, + {downloaderTypes.DatasetPDBSeqReq, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.PDBSeqReq }}, + {downloaderTypes.DatasetRNACentral, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.RNACentral }}, + {downloaderTypes.DatasetNT, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.NT }}, + {downloaderTypes.DatasetRFam, func(p *datav1.ProteinDatabase) bool { return p.Spec.Datasets.RFam }}, +} + +func (e *datasetEnumerator) FromSpec(database *datav1.ProteinDatabase) []downloaderTypes.Dataset { + out := make([]downloaderTypes.Dataset, 0, len(datasetEntries)) + for _, entry := range datasetEntries { + if entry.enabled(database) { + out = append(out, entry.dataset) + } + } + return out +} + +func (e *datasetEnumerator) All() []downloaderTypes.Dataset { + out := make([]downloaderTypes.Dataset, 0, len(datasetEntries)) + for _, entry := range datasetEntries { + out = append(out, entry.dataset) + } + return out +} diff --git a/internal/database/datasets_test.go b/internal/database/datasets_test.go new file mode 100644 index 0000000..e2fecf0 --- /dev/null +++ b/internal/database/datasets_test.go @@ -0,0 +1,51 @@ +package database + +import ( + "slices" + "testing" + + downloaderTypes "github.com/kubefold/downloader/pkg/types" + + datav1 "github.com/kubefold/operator/api/v1" +) + +func TestDatasetEnumeratorAll(t *testing.T) { + enumerator := NewDatasetEnumerator() + all := enumerator.All() + if len(all) != 9 { + t.Fatalf("expected 9 datasets, got %d", len(all)) + } +} + +func TestDatasetEnumeratorFromSpec(t *testing.T) { + enumerator := NewDatasetEnumerator() + database := &datav1.ProteinDatabase{ + Spec: datav1.ProteinDatabaseSpec{ + Datasets: datav1.ProteinDatabaseDatasetSelection{ + BFD: true, UniRef90: true, RFam: true, + }, + }, + } + got := enumerator.FromSpec(database) + wantContains := []downloaderTypes.Dataset{ + downloaderTypes.DatasetBFD, + downloaderTypes.DatasetUniRef90, + downloaderTypes.DatasetRFam, + } + if len(got) != 3 { + t.Fatalf("expected 3 datasets, got %d", len(got)) + } + for _, expected := range wantContains { + if !slices.Contains(got, expected) { + t.Fatalf("expected dataset %v in result", expected) + } + } +} + +func TestDatasetEnumeratorEmpty(t *testing.T) { + enumerator := NewDatasetEnumerator() + got := enumerator.FromSpec(&datav1.ProteinDatabase{}) + if len(got) != 0 { + t.Fatalf("expected empty result for empty spec, got %d", len(got)) + } +} diff --git a/internal/database/finalizer.go b/internal/database/finalizer.go new file mode 100644 index 0000000..88a6a4b --- /dev/null +++ b/internal/database/finalizer.go @@ -0,0 +1,96 @@ +package database + +import ( + "context" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const Finalizer = "data.kubefold.io/finalizer" + +type FinalizerReconciler interface { + HandleDeletion(ctx context.Context, database *datav1.ProteinDatabase) (ctrl.Result, error) + Ensure(ctx context.Context, database *datav1.ProteinDatabase) (ctrl.Result, error) +} + +type finalizerReconciler struct { + client client.Client +} + +func NewFinalizerReconciler(c client.Client) FinalizerReconciler { + return &finalizerReconciler{client: c} +} + +func (f *finalizerReconciler) HandleDeletion(ctx context.Context, database *datav1.ProteinDatabase) (ctrl.Result, error) { + log := logf.FromContext(ctx) + if !controllerutil.ContainsFinalizer(database, Finalizer) { + return ctrl.Result{}, nil + } + if err := f.deleteJobs(ctx, database); err != nil { + log.Error(err, "Failed to clean up downloader jobs") + return ctrl.Result{}, err + } + if err := f.deletePVC(ctx, database); err != nil { + log.Error(err, "Failed to clean up PVC") + return ctrl.Result{}, err + } + controllerutil.RemoveFinalizer(database, Finalizer) + if err := f.client.Update(ctx, database); err != nil { + log.Error(err, "Failed to remove finalizer") + return ctrl.Result{}, err + } + return ctrl.Result{}, nil +} + +func (f *finalizerReconciler) Ensure(ctx context.Context, database *datav1.ProteinDatabase) (ctrl.Result, error) { + log := logf.FromContext(ctx) + if controllerutil.ContainsFinalizer(database, Finalizer) { + return ctrl.Result{}, nil + } + controllerutil.AddFinalizer(database, Finalizer) + if err := f.client.Update(ctx, database); err != nil { + log.Error(err, "Failed to add finalizer") + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (f *finalizerReconciler) deleteJobs(ctx context.Context, database *datav1.ProteinDatabase) error { + jobs := &batchv1.JobList{} + if err := f.client.List(ctx, jobs, + client.InNamespace(database.Namespace), + client.MatchingLabels{shared.LabelDatabase: database.Name}, + ); err != nil { + return err + } + for i := range jobs.Items { + if err := shared.DeleteInBackground(ctx, f.client, &jobs.Items[i]); err != nil { + return err + } + } + return nil +} + +func (f *finalizerReconciler) deletePVC(ctx context.Context, database *datav1.ProteinDatabase) error { + if database.Spec.Volume.Selector != nil { + return nil + } + pvc := &corev1.PersistentVolumeClaim{} + if err := f.client.Get(ctx, types.NamespacedName{Name: shared.DatabasePVCName(database.Name), Namespace: database.Namespace}, pvc); err != nil { + if errors.IsNotFound(err) { + return nil + } + return err + } + return shared.DeleteInBackground(ctx, f.client, pvc) +} diff --git a/internal/database/finalizer_test.go b/internal/database/finalizer_test.go new file mode 100644 index 0000000..fb9066f --- /dev/null +++ b/internal/database/finalizer_test.go @@ -0,0 +1,101 @@ +package database + +import ( + "context" + "testing" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +func newScheme(t *testing.T) *runtime.Scheme { + t.Helper() + scheme := runtime.NewScheme() + if err := datav1.AddToScheme(scheme); err != nil { + t.Fatalf("AddToScheme: %v", err) + } + if err := corev1.AddToScheme(scheme); err != nil { + t.Fatalf("AddToScheme corev1: %v", err) + } + if err := batchv1.AddToScheme(scheme); err != nil { + t.Fatalf("AddToScheme batchv1: %v", err) + } + return scheme +} + +func TestFinalizerHandleDeletionRemovesLabeledJobsAndPVC(t *testing.T) { + scheme := newScheme(t) + database := &datav1.ProteinDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: "alpha", + Namespace: "default", + Finalizers: []string{Finalizer}, + }, + } + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "alpha-bfd-downloader", + Namespace: "default", + Labels: map[string]string{shared.LabelDatabase: "alpha"}, + }, + } + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: shared.DatabasePVCName("alpha"), + Namespace: "default", + }, + } + c := fake.NewClientBuilder().WithScheme(scheme).WithObjects(database, job, pvc).Build() + + reconciler := NewFinalizerReconciler(c) + if _, err := reconciler.HandleDeletion(context.Background(), database); err != nil { + t.Fatalf("HandleDeletion failed: %v", err) + } + + if err := c.Get(context.Background(), types.NamespacedName{Name: job.Name, Namespace: job.Namespace}, &batchv1.Job{}); !errors.IsNotFound(err) { + t.Fatalf("expected job to be deleted, got: %v", err) + } + if err := c.Get(context.Background(), types.NamespacedName{Name: pvc.Name, Namespace: pvc.Namespace}, &corev1.PersistentVolumeClaim{}); !errors.IsNotFound(err) { + t.Fatalf("expected pvc to be deleted, got: %v", err) + } +} + +func TestFinalizerHandleDeletionSkipsPVCWithPreboundSelector(t *testing.T) { + scheme := newScheme(t) + database := &datav1.ProteinDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: "beta", + Namespace: "default", + Finalizers: []string{Finalizer}, + }, + Spec: datav1.ProteinDatabaseSpec{ + Volume: datav1.ProteinDatabaseVolume{ + Selector: &metav1.LabelSelector{MatchLabels: map[string]string{"role": "static"}}, + }, + }, + } + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: shared.DatabasePVCName("beta"), + Namespace: "default", + }, + } + c := fake.NewClientBuilder().WithScheme(scheme).WithObjects(database, pvc).Build() + + reconciler := NewFinalizerReconciler(c) + if _, err := reconciler.HandleDeletion(context.Background(), database); err != nil { + t.Fatalf("HandleDeletion failed: %v", err) + } + + if err := c.Get(context.Background(), types.NamespacedName{Name: pvc.Name, Namespace: pvc.Namespace}, &corev1.PersistentVolumeClaim{}); err != nil { + t.Fatalf("PVC bound to pre-existing PV must not be deleted: %v", err) + } +} diff --git a/internal/database/job.go b/internal/database/job.go new file mode 100644 index 0000000..083b400 --- /dev/null +++ b/internal/database/job.go @@ -0,0 +1,122 @@ +package database + +import ( + "context" + "strings" + + downloaderTypes "github.com/kubefold/downloader/pkg/types" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const ( + DownloaderImage = "ghcr.io/kubefold/downloader" + databaseMountPath = "/public_databases" + downloaderContainer = "downloader" +) + +type JobReconciler interface { + EnsureAll(ctx context.Context, database *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim) error +} + +type jobReconciler struct { + client client.Client + scheme *runtime.Scheme + enumerator DatasetEnumerator +} + +func NewJobReconciler(c client.Client, scheme *runtime.Scheme, enumerator DatasetEnumerator) JobReconciler { + return &jobReconciler{client: c, scheme: scheme, enumerator: enumerator} +} + +func (j *jobReconciler) EnsureAll(ctx context.Context, database *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim) error { + log := logf.FromContext(ctx) + for _, dataset := range j.enumerator.FromSpec(database) { + if err := j.ensureSingle(ctx, database, pvc, dataset); err != nil { + log.Error(err, "Failed to ensure job for dataset", "dataset", dataset) + return err + } + } + return nil +} + +func (j *jobReconciler) ensureSingle(ctx context.Context, database *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim, dataset downloaderTypes.Dataset) error { + log := logf.FromContext(ctx) + jobName := normalizeJobName(shared.DownloaderJobName(database.Name, dataset.ShortName())) + + existing := &batchv1.Job{} + err := j.client.Get(ctx, types.NamespacedName{Name: jobName, Namespace: database.Namespace}, existing) + if err == nil { + return nil + } + if !errors.IsNotFound(err) { + return err + } + + job := j.build(database, pvc, dataset, jobName) + if err := controllerutil.SetControllerReference(database, job, j.scheme); err != nil { + return err + } + log.Info("Creating downloader job", "jobName", jobName, "dataset", dataset) + if err := j.client.Create(ctx, job); err != nil && !errors.IsAlreadyExists(err) { + return err + } + return nil +} + +func (j *jobReconciler) build(database *datav1.ProteinDatabase, pvc *corev1.PersistentVolumeClaim, dataset downloaderTypes.Dataset, jobName string) *batchv1.Job { + labels := shared.DownloaderJobLabels(database.Name, dataset.ShortName()) + return &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName, + Namespace: database.Namespace, + Labels: labels, + }, + Spec: batchv1.JobSpec{ + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{Labels: labels}, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyOnFailure, + Containers: []corev1.Container{ + { + Name: downloaderContainer, + Image: DownloaderImage, + ImagePullPolicy: corev1.PullAlways, + Env: []corev1.EnvVar{ + {Name: "DATASET", Value: dataset.String()}, + {Name: "DESTINATION", Value: databaseMountPath}, + }, + VolumeMounts: []corev1.VolumeMount{ + {Name: "databases", MountPath: databaseMountPath}, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "databases", + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvc.Name, + }, + }, + }, + }, + }, + }, + }, + } +} + +func normalizeJobName(name string) string { + return strings.ReplaceAll(strings.ToLower(name), "_", "-") +} diff --git a/internal/database/pvc.go b/internal/database/pvc.go new file mode 100644 index 0000000..971157c --- /dev/null +++ b/internal/database/pvc.go @@ -0,0 +1,98 @@ +package database + +import ( + "context" + "fmt" + "time" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const ReconcileInterval = 10 * time.Second + +type VolumeReconciler interface { + Ensure(ctx context.Context, database *datav1.ProteinDatabase) (*corev1.PersistentVolumeClaim, *ctrl.Result, error) +} + +type volumeReconciler struct { + client client.Client + scheme *runtime.Scheme + sizer Sizer +} + +func NewVolumeReconciler(c client.Client, scheme *runtime.Scheme, sizer Sizer) VolumeReconciler { + return &volumeReconciler{client: c, scheme: scheme, sizer: sizer} +} + +func (v *volumeReconciler) Ensure(ctx context.Context, database *datav1.ProteinDatabase) (*corev1.PersistentVolumeClaim, *ctrl.Result, error) { + log := logf.FromContext(ctx) + pvcName := shared.DatabasePVCName(database.Name) + + pvc := &corev1.PersistentVolumeClaim{} + err := v.client.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: database.Namespace}, pvc) + if errors.IsNotFound(err) { + pvc, err = v.create(ctx, database, pvcName) + if err != nil { + log.Error(err, "Failed to create PVC") + return nil, nil, err + } + log.Info("Created new PVC", "pvcName", pvc.Name) + } else if err != nil { + log.Error(err, "Failed to get PVC") + return nil, nil, err + } + + if pvc.Status.Phase != corev1.ClaimBound { + log.Info("PVC is not bound yet", "pvcName", pvc.Name, "phase", pvc.Status.Phase) + result := ctrl.Result{Requeue: true, RequeueAfter: ReconcileInterval} + return pvc, &result, nil + } + return pvc, nil, nil +} + +func (v *volumeReconciler) create(ctx context.Context, database *datav1.ProteinDatabase, pvcName string) (*corev1.PersistentVolumeClaim, error) { + labels := shared.MergeLabels( + database.Spec.Volume.Labels, + shared.DatabaseLabels(database.Name, "proteindatabase"), + ) + + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: pvcName, + Namespace: database.Namespace, + Labels: labels, + Annotations: database.Spec.Volume.Annotations, + }, + Spec: corev1.PersistentVolumeClaimSpec{ + AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteMany}, + StorageClassName: database.Spec.Volume.StorageClassName, + Resources: corev1.VolumeResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceStorage: resource.MustParse(fmt.Sprintf("%dGi", v.sizer.RequestedGigabytes(database))), + }, + }, + }, + } + if database.Spec.Volume.Selector != nil { + pvc.Spec.Selector = database.Spec.Volume.Selector + } + if err := controllerutil.SetControllerReference(database, pvc, v.scheme); err != nil { + return nil, err + } + if err := v.client.Create(ctx, pvc); err != nil { + return nil, err + } + return pvc, nil +} diff --git a/internal/database/size.go b/internal/database/size.go new file mode 100644 index 0000000..7c2cdee --- /dev/null +++ b/internal/database/size.go @@ -0,0 +1,28 @@ +package database + +import ( + datav1 "github.com/kubefold/operator/api/v1" +) + +const sizeBufferGigabytes = int64(2) + +type Sizer interface { + RequestedGigabytes(database *datav1.ProteinDatabase) int64 +} + +type sizer struct { + enumerator DatasetEnumerator + buffer int64 +} + +func NewSizer(enumerator DatasetEnumerator) Sizer { + return &sizer{enumerator: enumerator, buffer: sizeBufferGigabytes} +} + +func (s *sizer) RequestedGigabytes(database *datav1.ProteinDatabase) int64 { + var totalBytes int64 + for _, dataset := range s.enumerator.FromSpec(database) { + totalBytes += dataset.Size() + } + return totalBytes/1024/1024/1024 + s.buffer +} diff --git a/internal/database/size_test.go b/internal/database/size_test.go new file mode 100644 index 0000000..ab5be63 --- /dev/null +++ b/internal/database/size_test.go @@ -0,0 +1,28 @@ +package database + +import ( + "testing" + + datav1 "github.com/kubefold/operator/api/v1" +) + +func TestSizerAddsBuffer(t *testing.T) { + sizer := NewSizer(NewDatasetEnumerator()) + empty := &datav1.ProteinDatabase{} + if got := sizer.RequestedGigabytes(empty); got != sizeBufferGigabytes { + t.Fatalf("expected buffer-only size %d for empty spec, got %d", sizeBufferGigabytes, got) + } +} + +func TestSizerSumsEnabledDatasets(t *testing.T) { + sizer := NewSizer(NewDatasetEnumerator()) + database := &datav1.ProteinDatabase{ + Spec: datav1.ProteinDatabaseSpec{ + Datasets: datav1.ProteinDatabaseDatasetSelection{BFD: true}, + }, + } + got := sizer.RequestedGigabytes(database) + if got < sizeBufferGigabytes { + t.Fatalf("size with one dataset must include buffer: %d", got) + } +} diff --git a/internal/observer/aggregate.go b/internal/observer/aggregate.go new file mode 100644 index 0000000..7295084 --- /dev/null +++ b/internal/observer/aggregate.go @@ -0,0 +1,40 @@ +package observer + +import ( + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/database" + "github.com/kubefold/operator/internal/util" +) + +type MetricsAggregator interface { + Aggregate(status *datav1.ProteinDatabaseStatus) +} + +type metricsAggregator struct { + enumerator database.DatasetEnumerator +} + +func NewMetricsAggregator(enumerator database.DatasetEnumerator) MetricsAggregator { + return &metricsAggregator{enumerator: enumerator} +} + +func (a *metricsAggregator) Aggregate(status *datav1.ProteinDatabaseStatus) { + var size, totalSize, speedBytesPerSecond int64 + for _, dataset := range a.enumerator.All() { + slot, ok := datasetSlots[dataset] + if !ok { + continue + } + progress := slot(status) + size += progress.Size + totalSize += progress.TotalSize + if progress.DeltaDuration != nil { + speedBytesPerSecond += util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration) + } + } + status.Size = util.FormatSize(size) + status.TotalSize = util.FormatSize(totalSize) + status.Progress = util.FormatPercentage(size, totalSize) + status.DownloadSpeed = util.FormatSpeed(speedBytesPerSecond) + status.DownloadStatus = classifyStatus(size, totalSize) +} diff --git a/internal/observer/entry.go b/internal/observer/entry.go new file mode 100644 index 0000000..62fd2f3 --- /dev/null +++ b/internal/observer/entry.go @@ -0,0 +1,32 @@ +package observer + +import ( + "encoding/json" + "time" + + downloaderTypes "github.com/kubefold/downloader/pkg/types" +) + +type LogEntry struct { + DatasetName string `json:"dataset"` + Type string `json:"type"` + Msg string `json:"msg"` + Size int64 `json:"size"` + Total int64 `json:"total"` + Unit string `json:"unit"` + Hash string `json:"hash,omitempty"` + Level string `json:"level"` + Time time.Time `json:"time,omitempty"` +} + +func (e *LogEntry) Dataset() downloaderTypes.Dataset { + return downloaderTypes.Dataset(e.DatasetName) +} + +func ParseLogEntry(line []byte) (LogEntry, bool) { + var entry LogEntry + if err := json.Unmarshal(line, &entry); err != nil { + return LogEntry{}, false + } + return entry, true +} diff --git a/internal/observer/entry_test.go b/internal/observer/entry_test.go new file mode 100644 index 0000000..abb61d8 --- /dev/null +++ b/internal/observer/entry_test.go @@ -0,0 +1,28 @@ +package observer + +import ( + "testing" +) + +func TestParseLogEntryValid(t *testing.T) { + line := []byte(`{"dataset":"bfd","type":"progress","size":1024,"total":2048,"unit":"B","level":"info"}`) + entry, ok := ParseLogEntry(line) + if !ok { + t.Fatal("expected ok=true for valid JSON") + } + if entry.DatasetName != "bfd" { + t.Fatalf("DatasetName = %q, want bfd", entry.DatasetName) + } + if entry.Size != 1024 || entry.Total != 2048 { + t.Fatalf("size/total mismatch: %d/%d", entry.Size, entry.Total) + } +} + +func TestParseLogEntryInvalid(t *testing.T) { + if _, ok := ParseLogEntry([]byte("not json")); ok { + t.Fatal("expected ok=false for invalid JSON") + } + if _, ok := ParseLogEntry([]byte("")); ok { + t.Fatal("expected ok=false for empty input") + } +} diff --git a/internal/observer/pods.go b/internal/observer/pods.go new file mode 100644 index 0000000..9efa5ac --- /dev/null +++ b/internal/observer/pods.go @@ -0,0 +1,62 @@ +package observer + +import ( + "bufio" + "context" + "fmt" + "io" + + corev1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes" + + "github.com/kubefold/operator/internal/util" +) + +const downloaderContainer = "downloader" + +type PodLogReader interface { + StreamLines(ctx context.Context, pod corev1.Pod, tailLines int64) ([]LogEntry, error) +} + +type podLogReader struct { + kubeClient kubernetes.Interface +} + +func NewPodLogReader(kubeClient kubernetes.Interface) PodLogReader { + return &podLogReader{kubeClient: kubeClient} +} + +func (r *podLogReader) StreamLines(ctx context.Context, pod corev1.Pod, tailLines int64) ([]LogEntry, error) { + options := corev1.PodLogOptions{ + Container: downloaderContainer, + TailLines: util.Int64Ptr(tailLines), + Follow: false, + } + request := r.kubeClient.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, &options) + stream, err := request.Stream(ctx) + if err != nil { + return nil, fmt.Errorf("failed to open pod log stream: %w", err) + } + defer func() { + if closeErr := stream.Close(); closeErr != nil { + log.Error(closeErr, "failed to close pod log stream", "pod", pod.Name) + } + }() + + reader := bufio.NewReader(stream) + entries := make([]LogEntry, 0) + for { + line, err := reader.ReadString('\n') + if line != "" { + if entry, ok := ParseLogEntry([]byte(line)); ok { + entries = append(entries, entry) + } + } + if err != nil { + if err == io.EOF { + return entries, nil + } + return nil, fmt.Errorf("error reading pod logs: %w", err) + } + } +} diff --git a/internal/observer/progress.go b/internal/observer/progress.go new file mode 100644 index 0000000..5fce1a6 --- /dev/null +++ b/internal/observer/progress.go @@ -0,0 +1,96 @@ +package observer + +import ( + downloaderTypes "github.com/kubefold/downloader/pkg/types" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/util" +) + +type datasetSlot func(*datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress + +var datasetSlots = map[downloaderTypes.Dataset]datasetSlot{ + downloaderTypes.DatasetMGYClusters: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.MGYClusters + }, + downloaderTypes.DatasetBFD: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.BFD + }, + downloaderTypes.DatasetUniRef90: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.UniRef90 + }, + downloaderTypes.DatasetUniProt: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.UniProt + }, + downloaderTypes.DatasetPDB: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.PDB + }, + downloaderTypes.DatasetPDBSeqReq: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.PDBSeqReq + }, + downloaderTypes.DatasetRNACentral: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.RNACentral + }, + downloaderTypes.DatasetNT: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.NT + }, + downloaderTypes.DatasetRFam: func(s *datav1.ProteinDatabaseStatus) *datav1.ProteinDatabaseDatasetDownloadProgress { + return &s.Datasets.RFam + }, +} + +type ProgressUpdater interface { + Update(status *datav1.ProteinDatabaseStatus, entry LogEntry) +} + +type progressUpdater struct{} + +func NewProgressUpdater() ProgressUpdater { + return &progressUpdater{} +} + +func (u *progressUpdater) Update(status *datav1.ProteinDatabaseStatus, entry LogEntry) { + if entry.DatasetName == "" { + return + } + slot, ok := datasetSlots[entry.Dataset()] + if !ok { + return + } + progress := computeProgress(slot(status), entry) + *slot(status) = progress +} + +func computeProgress(previous *datav1.ProteinDatabaseDatasetDownloadProgress, entry LogEntry) datav1.ProteinDatabaseDatasetDownloadProgress { + totalSize := entry.Dataset().Size() + progress := datav1.ProteinDatabaseDatasetDownloadProgress{ + DownloadStatus: classifyStatus(entry.Size, totalSize), + Size: entry.Size, + TotalSize: totalSize, + Progress: util.FormatPercentage(entry.Size, totalSize), + LastUpdate: &metav1.Time{Time: entry.Time}, + } + if previous.LastUpdate != nil { + progress.Delta = progress.Size - previous.Size + progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(previous.LastUpdate.Time)} + progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) + } + if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { + progress.Delta = 0 + progress.DeltaDuration = nil + progress.DownloadSpeed = "" + } + return progress +} + +func classifyStatus(size, totalSize int64) datav1.ProteinDatabaseDownloadStatus { + switch { + case size == 0: + return datav1.ProteinDatabaseDownloadStatusNotStarted + case size < totalSize: + return datav1.ProteinDatabaseDownloadStatusDownloading + default: + return datav1.ProteinDatabaseDownloadStatusCompleted + } +} diff --git a/internal/observer/progress_test.go b/internal/observer/progress_test.go new file mode 100644 index 0000000..bdb3ea0 --- /dev/null +++ b/internal/observer/progress_test.go @@ -0,0 +1,60 @@ +package observer + +import ( + "testing" + + downloaderTypes "github.com/kubefold/downloader/pkg/types" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/database" +) + +func TestDispatchTableCoversAllDatasets(t *testing.T) { + enumerator := database.NewDatasetEnumerator() + for _, dataset := range enumerator.All() { + if _, ok := datasetSlots[dataset]; !ok { + t.Fatalf("dispatch table missing slot for dataset %q", dataset) + } + } + if len(datasetSlots) != len(enumerator.All()) { + t.Fatalf("dispatch table has %d entries, enumerator declares %d", len(datasetSlots), len(enumerator.All())) + } +} + +func TestProgressUpdaterIgnoresUnknownDataset(t *testing.T) { + updater := NewProgressUpdater() + status := &datav1.ProteinDatabaseStatus{} + updater.Update(status, LogEntry{DatasetName: "unknown_dataset", Size: 100}) +} + +func TestProgressUpdaterEmptyDatasetName(t *testing.T) { + updater := NewProgressUpdater() + status := &datav1.ProteinDatabaseStatus{} + updater.Update(status, LogEntry{}) +} + +func TestProgressUpdaterClassifiesStatus(t *testing.T) { + updater := NewProgressUpdater() + status := &datav1.ProteinDatabaseStatus{} + updater.Update(status, LogEntry{ + DatasetName: string(downloaderTypes.DatasetBFD), + Size: 0, + }) + if status.Datasets.BFD.DownloadStatus != datav1.ProteinDatabaseDownloadStatusNotStarted { + t.Fatalf("expected NotStarted, got %q", status.Datasets.BFD.DownloadStatus) + } + updater.Update(status, LogEntry{ + DatasetName: string(downloaderTypes.DatasetBFD), + Size: downloaderTypes.DatasetBFD.Size() / 2, + }) + if status.Datasets.BFD.DownloadStatus != datav1.ProteinDatabaseDownloadStatusDownloading { + t.Fatalf("expected Downloading, got %q", status.Datasets.BFD.DownloadStatus) + } + updater.Update(status, LogEntry{ + DatasetName: string(downloaderTypes.DatasetBFD), + Size: downloaderTypes.DatasetBFD.Size(), + }) + if status.Datasets.BFD.DownloadStatus != datav1.ProteinDatabaseDownloadStatusCompleted { + t.Fatalf("expected Completed, got %q", status.Datasets.BFD.DownloadStatus) + } +} diff --git a/internal/observer/proteindatabase_log_observer.go b/internal/observer/proteindatabase_log_observer.go index 1787b10..4f321ba 100644 --- a/internal/observer/proteindatabase_log_observer.go +++ b/internal/observer/proteindatabase_log_observer.go @@ -1,24 +1,26 @@ package observer import ( - "bufio" "context" - "encoding/json" "fmt" - "io" + "sync" "time" - "github.com/kubefold/operator/internal/util" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - downloaderTypes "github.com/kubefold/downloader/pkg/types" - datav1 "github.com/kubefold/operator/api/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/labels" - "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes" "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/database" + "github.com/kubefold/operator/internal/shared" +) + +const ( + defaultInterval = 500 * time.Millisecond + defaultTailLines = int64(100) + databaseLabelKey = shared.LabelDatabase ) var log = logf.Log.WithName("proteindatabase_log_observer") @@ -27,84 +29,90 @@ type LogObserver interface { Start(ctx context.Context) error } -type logObserver struct { - client client.Client - kubeClient kubernetes.Interface - stopCh chan struct{} +type Option func(*logObserver) + +func WithInterval(interval time.Duration) Option { + return func(o *logObserver) { + o.interval = interval + } } -type LogEntry struct { - DatasetName string `json:"dataset"` - Type string `json:"type"` - Msg string `json:"msg"` - Size int64 `json:"size"` - Total int64 `json:"total"` - Unit string `json:"unit"` - Hash string `json:"hash,omitempty"` - Level string `json:"level"` - Time time.Time `json:"time,omitempty"` +func WithTailLines(tailLines int64) Option { + return func(o *logObserver) { + o.tailLines = tailLines + } } -func (l *LogEntry) Dataset() downloaderTypes.Dataset { - return downloaderTypes.Dataset(l.DatasetName) +type logObserver struct { + client client.Client + kubeClient kubernetes.Interface + reader PodLogReader + progress ProgressUpdater + publisher StatusPublisher + interval time.Duration + tailLines int64 + waitGroup sync.WaitGroup } -func NewLogObserver(c client.Client, kubeClient kubernetes.Interface) LogObserver { - return &logObserver{ +func NewLogObserver(c client.Client, kubeClient kubernetes.Interface, options ...Option) LogObserver { + enumerator := database.NewDatasetEnumerator() + o := &logObserver{ client: c, kubeClient: kubeClient, - stopCh: make(chan struct{}), + reader: NewPodLogReader(kubeClient), + progress: NewProgressUpdater(), + publisher: NewStatusPublisher(c, NewMetricsAggregator(enumerator)), + interval: defaultInterval, + tailLines: defaultTailLines, + } + for _, apply := range options { + apply(o) } + return o } func (o *logObserver) Start(ctx context.Context) error { - go o.run(ctx) - return nil -} - -func (o *logObserver) Stop() { - close(o.stopCh) -} - -func (o *logObserver) run(ctx context.Context) { - ticker := time.NewTicker(500 * time.Millisecond) + log.Info("Log observer started", "interval", o.interval) + ticker := time.NewTicker(o.interval) defer ticker.Stop() for { select { case <-ticker.C: - if err := o.updateProteinDatabaseStatus(ctx); err != nil { - log.Error(err, "Error updating ProteinDatabase status") - } - case <-o.stopCh: - log.Info("Log observer stopped") - return + o.tick(ctx) case <-ctx.Done(): log.Info("Context done, stopping log observer") - return + o.waitGroup.Wait() + return nil } } } -func (o *logObserver) updateProteinDatabaseStatus(ctx context.Context) error { - var proteinDatabases datav1.ProteinDatabaseList - if err := o.client.List(ctx, &proteinDatabases); err != nil { - return fmt.Errorf("failed to list protein databases: %w", err) +func (o *logObserver) tick(ctx context.Context) { + o.waitGroup.Add(1) + defer o.waitGroup.Done() + if err := o.updateAllStatuses(ctx); err != nil { + log.Error(err, "Error updating ProteinDatabase statuses") } +} - for i := range proteinDatabases.Items { - pd := &proteinDatabases.Items[i] - if err := o.processProteinDatabase(ctx, pd); err != nil { - log.Error(err, "Error processing ProteinDatabase", "name", pd.Name, "namespace", pd.Namespace) +func (o *logObserver) updateAllStatuses(ctx context.Context) error { + databases := &datav1.ProteinDatabaseList{} + if err := o.client.List(ctx, databases); err != nil { + return fmt.Errorf("failed to list protein databases: %w", err) + } + for i := range databases.Items { + databasePointer := &databases.Items[i] + if err := o.processDatabase(ctx, databasePointer); err != nil { + log.Error(err, "Error processing ProteinDatabase", "name", databasePointer.Name, "namespace", databasePointer.Namespace) continue } } - return nil } -func (o *logObserver) processProteinDatabase(ctx context.Context, pd *datav1.ProteinDatabase) error { - pods, err := o.findDownloadPods(ctx, pd) +func (o *logObserver) processDatabase(ctx context.Context, db *datav1.ProteinDatabase) error { + pods, err := o.findDownloadPods(ctx, db) if err != nil { return fmt.Errorf("failed to find download pods: %w", err) } @@ -112,24 +120,27 @@ func (o *logObserver) processProteinDatabase(ctx context.Context, pd *datav1.Pro return nil } - proteinDatabaseStatus := pd.Status.DeepCopy() + newStatus := db.Status.DeepCopy() for _, pod := range pods { - if err := o.processPodsLogs(ctx, pod, proteinDatabaseStatus); err != nil { - log.Error(err, "Error processing pod logs", "pod", pod.Name, "namespace", pod.Namespace) + entries, err := o.reader.StreamLines(ctx, pod, o.tailLines) + if err != nil { + log.Error(err, "Error reading pod logs", "pod", pod.Name) continue } + for _, entry := range entries { + o.progress.Update(newStatus, entry) + } } - - return o.updateStatus(ctx, pd, proteinDatabaseStatus) + return o.publisher.Publish(ctx, db, newStatus) } -func (o *logObserver) findDownloadPods(ctx context.Context, pd *datav1.ProteinDatabase) ([]corev1.Pod, error) { - var podList corev1.PodList - err := o.client.List(ctx, &podList, +func (o *logObserver) findDownloadPods(ctx context.Context, db *datav1.ProteinDatabase) ([]corev1.Pod, error) { + podList := &corev1.PodList{} + err := o.client.List(ctx, podList, &client.ListOptions{ - Namespace: pd.Namespace, + Namespace: db.Namespace, LabelSelector: labels.SelectorFromSet(map[string]string{ - "data.kubefold.io/database": pd.Name, + databaseLabelKey: db.Name, }), }) if err != nil { @@ -137,227 +148,3 @@ func (o *logObserver) findDownloadPods(ctx context.Context, pd *datav1.ProteinDa } return podList.Items, nil } - -//nolint:gocyclo -func (o *logObserver) processPodsLogs(ctx context.Context, pod corev1.Pod, proteinDatabaseStatus *datav1.ProteinDatabaseStatus) error { - podLogOpts := corev1.PodLogOptions{ - Container: "downloader", - TailLines: util.Int64Ptr(100), - Follow: false, - } - - req := o.kubeClient.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, &podLogOpts) - podLogs, err := req.Stream(ctx) - if err != nil { - return fmt.Errorf("failed to open pod log stream: %w", err) - } - //nolint:errcheck - defer podLogs.Close() - - reader := bufio.NewReader(podLogs) - for { - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - break - } - return fmt.Errorf("error reading pod logs: %w", err) - } - - var logEntry LogEntry - if err := json.Unmarshal([]byte(line), &logEntry); err != nil { - continue - } - - if logEntry.DatasetName != "" { - var downloadStatus datav1.ProteinDatabaseDownloadStatus - if logEntry.Size == 0 { - downloadStatus = datav1.ProteinDatabaseDownloadStatusNotStarted - } else if logEntry.Size > 0 && logEntry.Size < logEntry.Dataset().Size() { - downloadStatus = datav1.ProteinDatabaseDownloadStatusDownloading - } else if logEntry.Size == logEntry.Dataset().Size() { - downloadStatus = datav1.ProteinDatabaseDownloadStatusCompleted - } - - progress := datav1.ProteinDatabaseDatasetDownloadProgress{ - DownloadStatus: downloadStatus, - Size: logEntry.Size, - TotalSize: logEntry.Dataset().Size(), - Progress: util.FormatPercentage(logEntry.Size, logEntry.Dataset().Size()), - LastUpdate: &metav1.Time{Time: logEntry.Time}, - } - - switch logEntry.Dataset() { - case downloaderTypes.DatasetRFam: - if proteinDatabaseStatus.Datasets.RFam.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.RFam.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.RFam.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.RFam = progress - case downloaderTypes.DatasetBFD: - if proteinDatabaseStatus.Datasets.BFD.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.BFD.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.BFD.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.BFD = progress - case downloaderTypes.DatasetUniProt: - if proteinDatabaseStatus.Datasets.UniProt.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.UniProt.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.UniProt.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.UniProt = progress - case downloaderTypes.DatasetUniRef90: - if proteinDatabaseStatus.Datasets.UniRef90.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.UniRef90.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.UniRef90.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.UniRef90 = progress - case downloaderTypes.DatasetRNACentral: - if proteinDatabaseStatus.Datasets.RNACentral.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.RNACentral.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.RNACentral.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.RNACentral = progress - case downloaderTypes.DatasetPDB: - if proteinDatabaseStatus.Datasets.PDB.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.PDB.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.PDB.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.PDB = progress - case downloaderTypes.DatasetPDBSeqReq: - if proteinDatabaseStatus.Datasets.PDBSeqReq.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.PDBSeqReq.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.PDBSeqReq.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.PDBSeqReq = progress - case downloaderTypes.DatasetNT: - if proteinDatabaseStatus.Datasets.NT.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.NT.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.NT.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.NT = progress - case downloaderTypes.DatasetMGYClusters: - if proteinDatabaseStatus.Datasets.MGYClusters.LastUpdate != nil { - progress.Delta = progress.Size - proteinDatabaseStatus.Datasets.MGYClusters.Size - progress.DeltaDuration = &metav1.Duration{Duration: progress.LastUpdate.Sub(proteinDatabaseStatus.Datasets.MGYClusters.LastUpdate.Time)} - progress.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(progress.Delta, progress.DeltaDuration.Duration)) - } - if progress.DownloadStatus == datav1.ProteinDatabaseDownloadStatusCompleted { - progress.Delta = 0 - progress.DeltaDuration = nil - progress.DownloadSpeed = "" - } - proteinDatabaseStatus.Datasets.MGYClusters = progress - } - } - } - - return nil -} - -func (o *logObserver) aggregateDownloadMetrics(proteinDatabaseStatus *datav1.ProteinDatabaseStatus) { - var size int64 - var totalSize int64 - var delta int64 - - progresses := []datav1.ProteinDatabaseDatasetDownloadProgress{ - proteinDatabaseStatus.Datasets.RFam, - proteinDatabaseStatus.Datasets.BFD, - proteinDatabaseStatus.Datasets.UniProt, - proteinDatabaseStatus.Datasets.UniRef90, - proteinDatabaseStatus.Datasets.RNACentral, - proteinDatabaseStatus.Datasets.PDB, - proteinDatabaseStatus.Datasets.PDBSeqReq, - proteinDatabaseStatus.Datasets.NT, - proteinDatabaseStatus.Datasets.MGYClusters, - } - - for _, progress := range progresses { - size += progress.Size - totalSize += progress.TotalSize - if progress.DeltaDuration != nil && progress.DeltaDuration.Duration == time.Second { - delta += progress.Delta - } - } - - proteinDatabaseStatus.Size = util.FormatSize(size) - proteinDatabaseStatus.TotalSize = util.FormatSize(totalSize) - if totalSize > 0 { - proteinDatabaseStatus.Progress = util.FormatPercentage(size, totalSize) - } - proteinDatabaseStatus.DownloadSpeed = util.FormatSpeed(util.CalculateDownloadSpeed(delta, time.Second)) - - if size == 0 { - proteinDatabaseStatus.DownloadStatus = datav1.ProteinDatabaseDownloadStatusNotStarted - } else if size > 0 && size < totalSize { - proteinDatabaseStatus.DownloadStatus = datav1.ProteinDatabaseDownloadStatusDownloading - } else if size == totalSize { - proteinDatabaseStatus.DownloadStatus = datav1.ProteinDatabaseDownloadStatusCompleted - } -} - -func (o *logObserver) updateStatus(ctx context.Context, pd *datav1.ProteinDatabase, proteinDatabaseStatus *datav1.ProteinDatabaseStatus) error { - proteinDatabase := &datav1.ProteinDatabase{} - if err := o.client.Get(ctx, types.NamespacedName{Name: pd.Name, Namespace: pd.Namespace}, proteinDatabase); err != nil { - return fmt.Errorf("failed to get latest ProteinDatabase: %w", err) - } - - o.aggregateDownloadMetrics(proteinDatabaseStatus) - - proteinDatabase.Status = *proteinDatabaseStatus - proteinDatabase.Status.LastUpdate = util.GetNow() - - if err := o.client.Status().Update(ctx, proteinDatabase); err != nil { - return fmt.Errorf("failed to update ProteinDatabase status: %w", err) - } - - return nil -} diff --git a/internal/observer/status.go b/internal/observer/status.go new file mode 100644 index 0000000..c8d5c5d --- /dev/null +++ b/internal/observer/status.go @@ -0,0 +1,41 @@ +package observer + +import ( + "context" + "fmt" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" + "github.com/kubefold/operator/internal/util" +) + +type StatusPublisher interface { + Publish(ctx context.Context, database *datav1.ProteinDatabase, newStatus *datav1.ProteinDatabaseStatus) error +} + +type statusPublisher struct { + client client.Client + aggregator MetricsAggregator +} + +func NewStatusPublisher(c client.Client, aggregator MetricsAggregator) StatusPublisher { + return &statusPublisher{client: c, aggregator: aggregator} +} + +func (p *statusPublisher) Publish(ctx context.Context, database *datav1.ProteinDatabase, newStatus *datav1.ProteinDatabaseStatus) error { + return shared.RetryOnConflict(ctx, func() error { + latest := &datav1.ProteinDatabase{} + if err := p.client.Get(ctx, types.NamespacedName{Name: database.Name, Namespace: database.Namespace}, latest); err != nil { + return fmt.Errorf("failed to get latest ProteinDatabase: %w", err) + } + p.aggregator.Aggregate(newStatus) + volumeName := latest.Status.VolumeName + latest.Status = *newStatus + latest.Status.VolumeName = volumeName + latest.Status.LastUpdate = util.GetNow() + return p.client.Status().Update(ctx, latest) + }) +} diff --git a/internal/prediction/cleanup.go b/internal/prediction/cleanup.go new file mode 100644 index 0000000..f83f2d1 --- /dev/null +++ b/internal/prediction/cleanup.go @@ -0,0 +1,77 @@ +package prediction + +import ( + "context" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubefold/operator/internal/shared" +) + +type ResourceCleaner interface { + DeleteJobsForPrediction(ctx context.Context, name, namespace string) error + DeletePVCForPrediction(ctx context.Context, name, namespace string) error + DeleteCompletedJobs(ctx context.Context, name, namespace string) error +} + +type resourceCleaner struct { + client client.Client +} + +func NewResourceCleaner(c client.Client) ResourceCleaner { + return &resourceCleaner{client: c} +} + +func (r *resourceCleaner) DeleteJobsForPrediction(ctx context.Context, name, namespace string) error { + jobs, err := r.listJobs(ctx, name, namespace) + if err != nil { + return err + } + for i := range jobs.Items { + if err := shared.DeleteInBackground(ctx, r.client, &jobs.Items[i]); err != nil { + return err + } + } + return nil +} + +func (r *resourceCleaner) DeletePVCForPrediction(ctx context.Context, name, namespace string) error { + pvc := &corev1.PersistentVolumeClaim{} + if err := r.client.Get(ctx, types.NamespacedName{Name: shared.PredictionDataPVCName(name), Namespace: namespace}, pvc); err != nil { + if errors.IsNotFound(err) { + return nil + } + return err + } + return shared.DeleteInBackground(ctx, r.client, pvc) +} + +func (r *resourceCleaner) DeleteCompletedJobs(ctx context.Context, name, namespace string) error { + jobs, err := r.listJobs(ctx, name, namespace) + if err != nil { + return err + } + for i := range jobs.Items { + if jobs.Items[i].Status.Succeeded > 0 { + if err := shared.DeleteInBackground(ctx, r.client, &jobs.Items[i]); err != nil { + return err + } + } + } + return nil +} + +func (r *resourceCleaner) listJobs(ctx context.Context, name, namespace string) (*batchv1.JobList, error) { + jobs := &batchv1.JobList{} + if err := r.client.List(ctx, jobs, + client.InNamespace(namespace), + client.MatchingLabels{shared.LabelPrediction: name}, + ); err != nil { + return nil, err + } + return jobs, nil +} diff --git a/internal/prediction/conditions.go b/internal/prediction/conditions.go new file mode 100644 index 0000000..fda4eef --- /dev/null +++ b/internal/prediction/conditions.go @@ -0,0 +1,43 @@ +package prediction + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const ( + ConditionTypeReady = "Ready" + ConditionTypeSearchSucceeded = "SearchSucceeded" + ConditionTypePredictSucceeded = "PredictSucceeded" + ConditionTypeUploadSucceeded = "UploadSucceeded" + ConditionTypeValidationFailed = "ValidationFailed" +) + +const ( + ReasonPhaseProgressed = "PhaseProgressed" + ReasonPhaseFailed = "PhaseFailed" + ReasonValidationFailed = "ValidationFailed" + ReasonJobTimeout = "JobTimeout" + ReasonRetriesExhausted = "RetriesExhausted" + ReasonCompleted = "Completed" +) + +type ConditionsManager interface { + Set(status *datav1.ProteinConformationPredictionStatus, condition metav1.Condition) +} + +type conditionsManager struct{} + +func NewConditionsManager() ConditionsManager { + return &conditionsManager{} +} + +func (m *conditionsManager) Set(status *datav1.ProteinConformationPredictionStatus, condition metav1.Condition) { + shared.SetCondition(&status.Conditions, condition) + if updated := shared.FindCondition(status.Conditions, condition.Type); updated != nil { + stamp := updated.LastTransitionTime + status.LastTransitionTime = &stamp + } +} diff --git a/internal/prediction/controller.go b/internal/prediction/controller.go new file mode 100644 index 0000000..1eaff9b --- /dev/null +++ b/internal/prediction/controller.go @@ -0,0 +1,186 @@ +package prediction + +import ( + "context" + "os" + "slices" + "strings" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" +) + +const ( + Finalizer = "proteinconformationprediction.data.kubefold.io/finalizer" + WeightsAllowedHostsEnv = "KUBEFOLD_WEIGHTS_ALLOWED_HOSTS" +) + +type Reconciler struct { + client.Client + Scheme *runtime.Scheme + Recorder record.EventRecorder + validator SpecValidator + cleaner ResourceCleaner + status StatusWriter + router PhaseRouter +} + +func NewReconciler(c client.Client, scheme *runtime.Scheme, recorder record.EventRecorder) *Reconciler { + allowedHosts := parseAllowedHosts(os.Getenv(WeightsAllowedHostsEnv)) + validator := NewSpecValidator(allowedHosts) + pvcBuilder := NewPVCBuilder() + nodeSelector := NewNodeSelectorTranslator() + jobBuilder := NewJobBuilder(nodeSelector) + input := NewInputEncoder() + retry := NewRetryPolicy(MaxRetries) + timeout := NewTimeoutChecker() + cleaner := NewResourceCleaner(c) + conditions := NewConditionsManager() + status := NewStatusWriter(c) + router := NewPhaseRouter(c, scheme, recorder, pvcBuilder, jobBuilder, input, retry, timeout, cleaner, conditions, status) + + return &Reconciler{ + Client: c, + Scheme: scheme, + Recorder: recorder, + validator: validator, + cleaner: cleaner, + status: status, + router: router, + } +} + +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions/status,verbs=get;update;patch +// +kubebuilder:rbac:groups=data.kubefold.io,resources=proteinconformationpredictions/finalizers,verbs=update +// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups="",resources=persistentvolumeclaims,verbs=get;list;watch;create;update;patch;delete + +func (r *Reconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := logf.FromContext(ctx) + + prediction := &datav1.ProteinConformationPrediction{} + if err := r.Get(ctx, req.NamespacedName, prediction); err != nil { + if errors.IsNotFound(err) { + return ctrl.Result{}, nil + } + log.Error(err, "Failed to get ProteinConformationPrediction") + return ctrl.Result{}, err + } + + if !prediction.DeletionTimestamp.IsZero() { + return r.handleDeletion(ctx, prediction) + } + + if !controllerutil.ContainsFinalizer(prediction, Finalizer) { + controllerutil.AddFinalizer(prediction, Finalizer) + if err := r.Update(ctx, prediction); err != nil { + log.Error(err, "Failed to add finalizer") + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil + } + + if err := r.validator.Validate(prediction); err != nil { + log.Error(err, "Invalid spec") + return r.failValidation(ctx, prediction, err.Error()) + } + + if prediction.Status.Phase == "" { + return r.initializeStatus(ctx, prediction) + } + + return r.router.Handle(ctx, prediction) +} + +func (r *Reconciler) handleDeletion(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + log := logf.FromContext(ctx) + if !controllerutil.ContainsFinalizer(prediction, Finalizer) { + return ctrl.Result{}, nil + } + if err := r.cleaner.DeleteJobsForPrediction(ctx, prediction.Name, prediction.Namespace); err != nil { + log.Error(err, "Failed to delete prediction jobs") + return ctrl.Result{}, err + } + if err := r.cleaner.DeletePVCForPrediction(ctx, prediction.Name, prediction.Namespace); err != nil { + log.Error(err, "Failed to delete prediction PVC") + return ctrl.Result{}, err + } + controllerutil.RemoveFinalizer(prediction, Finalizer) + if err := r.Update(ctx, prediction); err != nil { + log.Error(err, "Failed to remove finalizer") + return ctrl.Result{}, err + } + return ctrl.Result{}, nil +} + +func (r *Reconciler) failValidation(ctx context.Context, prediction *datav1.ProteinConformationPrediction, message string) (ctrl.Result, error) { + conditions := NewConditionsManager() + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed + p.Status.Error = message + conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeValidationFailed, + Status: metav1.ConditionTrue, + Reason: ReasonValidationFailed, + Message: message, + }) + conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeReady, + Status: metav1.ConditionFalse, + Reason: ReasonValidationFailed, + Message: message, + }) + }); err != nil { + return ctrl.Result{}, err + } + r.Recorder.Event(prediction, corev1.EventTypeWarning, ReasonValidationFailed, message) + return ctrl.Result{}, nil +} + +func (r *Reconciler) initializeStatus(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + prefix := r.validator.SequencePrefix(prediction.Spec.Protein.Sequence) + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseNotStarted + p.Status.SequencePrefix = prefix + p.Status.SearchRetryCount = 0 + p.Status.PredictRetryCount = 0 + p.Status.UploadRetryCount = 0 + }); err != nil { + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *Reconciler) SetupWithManager(mgr ctrl.Manager) error { + if r.Recorder == nil { + r.Recorder = mgr.GetEventRecorderFor("proteinconformationprediction-controller") + } + return ctrl.NewControllerManagedBy(mgr). + For(&datav1.ProteinConformationPrediction{}). + Owns(&corev1.PersistentVolumeClaim{}). + Owns(&batchv1.Job{}). + Named("proteinconformationprediction"). + Complete(r) +} + +func parseAllowedHosts(raw string) []string { + if raw == "" || raw == "*" { + return nil + } + hosts := strings.FieldsFunc(raw, func(r rune) bool { return r == ',' }) + for i, host := range hosts { + hosts[i] = strings.TrimSpace(host) + } + return slices.DeleteFunc(hosts, func(host string) bool { return host == "" }) +} diff --git a/internal/prediction/input.go b/internal/prediction/input.go new file mode 100644 index 0000000..749d97f --- /dev/null +++ b/internal/prediction/input.go @@ -0,0 +1,50 @@ +package prediction + +import ( + "encoding/base64" + "encoding/json" + "fmt" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/alphafold" +) + +type InputEncoder interface { + Encode(prediction *datav1.ProteinConformationPrediction, forPredictionPhase bool) (string, error) +} + +type inputEncoder struct{} + +func NewInputEncoder() InputEncoder { + return &inputEncoder{} +} + +func (e *inputEncoder) Encode(prediction *datav1.ProteinConformationPrediction, forPredictionPhase bool) (string, error) { + input := alphafold.Input{ + Name: fmt.Sprintf("%s-%s", prediction.Namespace, prediction.Name), + Sequences: []alphafold.Sequence{ + { + Protein: alphafold.Protein{ + Sequence: prediction.Spec.Protein.Sequence, + ID: prediction.Spec.Protein.ID, + }, + }, + }, + ModelSeeds: prediction.Spec.Model.Seeds, + Dialect: "alphafold3", + Version: 1, + } + if forPredictionPhase { + empty := "" + emptyList := make([]string, 0) + input.Sequences[0].Protein.Templates = &emptyList + input.Sequences[0].Protein.UnpairedMSA = &empty + input.Sequences[0].Protein.PairedMSA = &empty + } + + encoded, err := json.Marshal(input) + if err != nil { + return "", fmt.Errorf("failed to marshal fold input: %w", err) + } + return base64.StdEncoding.EncodeToString(encoded), nil +} diff --git a/internal/prediction/job.go b/internal/prediction/job.go new file mode 100644 index 0000000..2299841 --- /dev/null +++ b/internal/prediction/job.go @@ -0,0 +1,290 @@ +package prediction + +import ( + "fmt" + "strings" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const ( + ManagerImage = "ghcr.io/kubefold/manager" + AlphafoldImage = "public.ecr.aws/k3x1v3b7/alphafold3:latest" + containerData = "data" + containerDB = "database" + mountPathData = "/data" + mountPathDB = "/public_databases" + weightsURLEnvVar = "WEIGHTS_URL" + weightsDownload = `set -eu; mkdir -p /data/models; wget --tries=3 --timeout=30 -O /data/models/af3.bin.zst "$WEIGHTS_URL"; unzstd /data/models/af3.bin.zst` + jobBackoffLimit = int32(2) + + containerInputPlacement = "input-placement" + containerWeightsPlacement = "weights-placement" + containerSearch = "search" + containerPredict = "predict" + containerUpload = "upload" + containerNotify = "notify" +) + +var disallowPrivilegeEscalation = false + +type JobBuilder interface { + BuildSearch(prediction *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job + BuildPredict(prediction *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job + BuildUpload(prediction *datav1.ProteinConformationPrediction, jobName, pvcName string) *batchv1.Job +} + +type jobBuilder struct { + nodeSelector NodeSelectorTranslator +} + +func NewJobBuilder(nodeSelector NodeSelectorTranslator) JobBuilder { + return &jobBuilder{nodeSelector: nodeSelector} +} + +func (b *jobBuilder) BuildSearch(prediction *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job { + backoffLimit := jobBackoffLimit + labels := shared.PredictionLabels(prediction.Name, "proteinconformationprediction-search", phaseSearchKey) + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName, + Namespace: prediction.Namespace, + Labels: labels, + }, + Spec: batchv1.JobSpec{ + BackoffLimit: &backoffLimit, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{Labels: labels}, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyNever, + InitContainers: []corev1.Container{inputPlacementContainer(encodedInput)}, + Containers: []corev1.Container{searchContainer(prediction.Spec.Job.Profile)}, + Volumes: predictionVolumes(pvcName, prediction.Spec.Database), + }, + }, + }, + } + + job.Spec.Template.Spec.Affinity = b.nodeSelector.ToAffinity(&prediction.Spec.Job.SearchNodeSelector) + return job +} + +func (b *jobBuilder) BuildPredict(prediction *datav1.ProteinConformationPrediction, jobName, pvcName, encodedInput string) *batchv1.Job { + backoffLimit := jobBackoffLimit + labels := shared.PredictionLabels(prediction.Name, "proteinconformationprediction-predict", phasePredictKey) + + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName, + Namespace: prediction.Namespace, + Labels: labels, + }, + Spec: batchv1.JobSpec{ + BackoffLimit: &backoffLimit, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{Labels: labels}, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyNever, + InitContainers: []corev1.Container{ + inputPlacementContainer(encodedInput), + weightsPlacementContainer(prediction.Spec.Model.Weights.HTTP), + }, + Containers: []corev1.Container{predictContainer(prediction.Spec.Job.Profile)}, + Volumes: predictionVolumes(pvcName, prediction.Spec.Database), + }, + }, + }, + } + + job.Spec.Template.Spec.Affinity = b.nodeSelector.ToAffinity(&prediction.Spec.Job.PredictionNodeSelector) + return job +} + +func (b *jobBuilder) BuildUpload(prediction *datav1.ProteinConformationPrediction, jobName, pvcName string) *batchv1.Job { + backoffLimit := jobBackoffLimit + labels := shared.PredictionLabels(prediction.Name, "proteinconformationprediction-upload", phaseUploadKey) + + return &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: jobName, + Namespace: prediction.Namespace, + Labels: labels, + }, + Spec: batchv1.JobSpec{ + BackoffLimit: &backoffLimit, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{Labels: labels}, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyNever, + Containers: []corev1.Container{ + uploadContainer(prediction), + notifyContainer(prediction), + }, + Volumes: []corev1.Volume{predictionDataVolume(pvcName)}, + }, + }, + }, + } +} + +func inputPlacementContainer(encodedInput string) corev1.Container { + return corev1.Container{ + Name: containerInputPlacement, + Image: ManagerImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Env: []corev1.EnvVar{ + {Name: "INPUT_PATH", Value: "/data/af_input"}, + {Name: "OUTPUT_PATH", Value: "/data/af_output"}, + {Name: "ENCODED_INPUT", Value: encodedInput}, + }, + VolumeMounts: predictionVolumeMounts(), + } +} + +func weightsPlacementContainer(weightsURL string) corev1.Container { + return corev1.Container{ + Name: containerWeightsPlacement, + Image: ManagerImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Command: []string{"sh", "-c", weightsDownload}, + Env: []corev1.EnvVar{ + {Name: weightsURLEnvVar, Value: weightsURL}, + }, + VolumeMounts: predictionVolumeMounts(), + } +} + +func searchContainer(profile datav1.ProteinConformationPredictionProfile) corev1.Container { + return corev1.Container{ + Name: containerSearch, + Image: AlphafoldImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Resources: resourcesForPhase(profile, phaseSearchKey), + Command: []string{"uv"}, + Args: []string{ + "run", "python3", "run_alphafold.py", + "--json_path=/data/af_input/fold_input.json", + "--output_dir=/data/af_output", + "--model_dir=/data/models", + "--db_dir=/public_databases", + "--run_inference=false", + }, + VolumeMounts: predictionVolumeMounts(), + } +} + +func predictContainer(profile datav1.ProteinConformationPredictionProfile) corev1.Container { + requirements := resourcesForPhase(profile, phasePredictKey) + requirements.Requests["nvidia.com/gpu"] = resource.MustParse("1") + if requirements.Limits == nil { + requirements.Limits = corev1.ResourceList{} + } + requirements.Limits["nvidia.com/gpu"] = resource.MustParse("1") + return corev1.Container{ + Name: containerPredict, + Image: AlphafoldImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Resources: requirements, + Command: []string{"uv"}, + Args: []string{ + "run", "python3", "run_alphafold.py", + "--json_path=/data/af_input/fold_input.json", + "--output_dir=/data/af_output", + "--model_dir=/data/models", + "--db_dir=/public_databases", + "--run_data_pipeline=false", + }, + VolumeMounts: predictionVolumeMounts(), + } +} + +func uploadContainer(prediction *datav1.ProteinConformationPrediction) corev1.Container { + return corev1.Container{ + Name: containerUpload, + Image: ManagerImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Resources: resourcesForPhase(prediction.Spec.Job.Profile, phaseUploadKey), + Env: []corev1.EnvVar{ + {Name: "INPUT_PATH", Value: "/data/af_input"}, + {Name: "OUTPUT_PATH", Value: "/data/af_output"}, + {Name: "BUCKET", Value: prediction.Spec.Destination.S3.Bucket}, + {Name: "AWS_REGION", Value: prediction.Spec.Destination.S3.Region}, + }, + VolumeMounts: []corev1.VolumeMount{ + {Name: containerData, MountPath: mountPathData}, + }, + } +} + +func notifyContainer(prediction *datav1.ProteinConformationPrediction) corev1.Container { + return corev1.Container{ + Name: containerNotify, + Image: ManagerImage, + ImagePullPolicy: corev1.PullAlways, + SecurityContext: restrictedSecurityContext(), + Resources: resourcesForPhase(prediction.Spec.Job.Profile, phaseUploadKey), + Env: []corev1.EnvVar{ + {Name: "INPUT_PATH", Value: "/data/af_input"}, + {Name: "OUTPUT_PATH", Value: "/data/af_output"}, + {Name: "NOTIFICATION_PHONES", Value: strings.Join(prediction.Spec.Notifications.SMS, ",")}, + {Name: "NOTIFICATION_MESSAGE", Value: fmt.Sprintf("Protein Conformation Prediction %s in namespace %s completed. Artifacts has been uploaded to %s", prediction.Name, prediction.Namespace, prediction.Spec.Destination.S3.Bucket)}, + {Name: "AWS_REGION", Value: prediction.Spec.Destination.S3.Region}, + }, + VolumeMounts: []corev1.VolumeMount{ + {Name: containerData, MountPath: mountPathData}, + }, + } +} + +func predictionVolumes(pvcName, databaseName string) []corev1.Volume { + return []corev1.Volume{ + predictionDataVolume(pvcName), + { + Name: containerDB, + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: shared.DatabasePVCName(databaseName), + }, + }, + }, + } +} + +func predictionDataVolume(pvcName string) corev1.Volume { + return corev1.Volume{ + Name: containerData, + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: pvcName, + }, + }, + } +} + +func predictionVolumeMounts() []corev1.VolumeMount { + return []corev1.VolumeMount{ + {Name: containerData, MountPath: mountPathData}, + {Name: containerDB, MountPath: mountPathDB}, + } +} + +func restrictedSecurityContext() *corev1.SecurityContext { + return &corev1.SecurityContext{ + AllowPrivilegeEscalation: &disallowPrivilegeEscalation, + Capabilities: &corev1.Capabilities{ + Drop: []corev1.Capability{"ALL"}, + }, + } +} diff --git a/internal/prediction/job_test.go b/internal/prediction/job_test.go new file mode 100644 index 0000000..a19099d --- /dev/null +++ b/internal/prediction/job_test.go @@ -0,0 +1,116 @@ +package prediction + +import ( + "strings" + "testing" + + corev1 "k8s.io/api/core/v1" + + datav1 "github.com/kubefold/operator/api/v1" +) + +func makePrediction() *datav1.ProteinConformationPrediction { + return &datav1.ProteinConformationPrediction{ + Spec: datav1.ProteinConformationPredictionSpec{ + Protein: datav1.ProteinConformationPredictionProtein{ + Sequence: "ABCDE", ID: []string{"id-1"}, + }, + Database: "db", + Destination: datav1.ProteinConformationPredictionDestination{ + S3: datav1.ProteinConformationPredictionDestinationS3{Bucket: "bkt", Region: "us-east-1"}, + }, + Model: datav1.ProteinConformationPredictionModel{ + Weights: datav1.ProteinConformationPredictionModelWeights{HTTP: "https://example.com/w"}, + }, + Job: datav1.ProteinConformationPredictionJob{ + PredictionNodeSelector: corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + {Key: "zone", Operator: corev1.NodeSelectorOpIn, Values: []string{"a", "b"}}, + }, + }, + }, + }, + }, + }, + } +} + +func TestBuildPredictWeightsViaEnvVar(t *testing.T) { + builder := NewJobBuilder(NewNodeSelectorTranslator()) + prediction := makePrediction() + + job := builder.BuildPredict(prediction, "pred-1", "pvc-1", "encoded") + weightsContainer := findContainer(job.Spec.Template.Spec.InitContainers, "weights-placement") + if weightsContainer == nil { + t.Fatal("weights-placement init container missing") + } + + for _, arg := range weightsContainer.Command { + if strings.Contains(arg, prediction.Spec.Model.Weights.HTTP) { + t.Fatalf("URL must not be interpolated into command, found: %q", arg) + } + } + + envFound := false + for _, env := range weightsContainer.Env { + if env.Name == weightsURLEnvVar { + if env.Value != prediction.Spec.Model.Weights.HTTP { + t.Fatalf("WEIGHTS_URL env value mismatch: %q vs %q", env.Value, prediction.Spec.Model.Weights.HTTP) + } + envFound = true + } + } + if !envFound { + t.Fatalf("WEIGHTS_URL env var not found") + } +} + +func TestBuildPredictUsesAffinityNotNodeSelector(t *testing.T) { + builder := NewJobBuilder(NewNodeSelectorTranslator()) + prediction := makePrediction() + job := builder.BuildPredict(prediction, "pred-1", "pvc-1", "encoded") + + if job.Spec.Template.Spec.NodeSelector != nil { + t.Fatal("NodeSelector must remain unset; affinity is used instead") + } + if job.Spec.Template.Spec.Affinity == nil { + t.Fatal("expected Affinity to be populated from PredictionNodeSelector") + } + terms := job.Spec.Template.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms + if len(terms) != 1 || len(terms[0].MatchExpressions[0].Values) != 2 { + t.Fatalf("multi-value In operator was collapsed: %+v", terms) + } +} + +func TestBuildSearchUsesAffinity(t *testing.T) { + builder := NewJobBuilder(NewNodeSelectorTranslator()) + prediction := makePrediction() + prediction.Spec.Job.SearchNodeSelector = corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + {Key: "gpu", Operator: corev1.NodeSelectorOpExists}, + }, + }, + }, + } + job := builder.BuildSearch(prediction, "pred-search", "pvc-1", "encoded") + if job.Spec.Template.Spec.Affinity == nil { + t.Fatal("expected Affinity to be set") + } + terms := job.Spec.Template.Spec.Affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms + if terms[0].MatchExpressions[0].Operator != corev1.NodeSelectorOpExists { + t.Fatal("Exists operator was dropped during translation") + } +} + +func findContainer(containers []corev1.Container, name string) *corev1.Container { + for i := range containers { + if containers[i].Name == name { + return &containers[i] + } + } + return nil +} diff --git a/internal/prediction/nodeselector.go b/internal/prediction/nodeselector.go new file mode 100644 index 0000000..b4a0c25 --- /dev/null +++ b/internal/prediction/nodeselector.go @@ -0,0 +1,26 @@ +package prediction + +import ( + corev1 "k8s.io/api/core/v1" +) + +type NodeSelectorTranslator interface { + ToAffinity(selector *corev1.NodeSelector) *corev1.Affinity +} + +type nodeSelectorTranslator struct{} + +func NewNodeSelectorTranslator() NodeSelectorTranslator { + return &nodeSelectorTranslator{} +} + +func (t *nodeSelectorTranslator) ToAffinity(selector *corev1.NodeSelector) *corev1.Affinity { + if selector == nil || len(selector.NodeSelectorTerms) == 0 { + return nil + } + return &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: selector.DeepCopy(), + }, + } +} diff --git a/internal/prediction/nodeselector_test.go b/internal/prediction/nodeselector_test.go new file mode 100644 index 0000000..d17c0a5 --- /dev/null +++ b/internal/prediction/nodeselector_test.go @@ -0,0 +1,60 @@ +package prediction + +import ( + "testing" + + corev1 "k8s.io/api/core/v1" +) + +func TestToAffinityPreservesAllOperators(t *testing.T) { + translator := NewNodeSelectorTranslator() + + selector := &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + {Key: "zone", Operator: corev1.NodeSelectorOpIn, Values: []string{"a", "b", "c"}}, + {Key: "role", Operator: corev1.NodeSelectorOpNotIn, Values: []string{"control-plane"}}, + {Key: "gpu", Operator: corev1.NodeSelectorOpExists}, + {Key: "spot", Operator: corev1.NodeSelectorOpDoesNotExist}, + {Key: "cpu", Operator: corev1.NodeSelectorOpGt, Values: []string{"4"}}, + {Key: "memory", Operator: corev1.NodeSelectorOpLt, Values: []string{"64"}}, + }, + MatchFields: []corev1.NodeSelectorRequirement{ + {Key: "metadata.name", Operator: corev1.NodeSelectorOpIn, Values: []string{"node-1"}}, + }, + }, + }, + } + + affinity := translator.ToAffinity(selector) + if affinity == nil { + t.Fatal("expected non-nil affinity") + } + if affinity.NodeAffinity == nil || affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution == nil { + t.Fatal("expected RequiredDuringSchedulingIgnoredDuringExecution to be populated") + } + terms := affinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms + if len(terms) != 1 { + t.Fatalf("expected 1 term, got %d", len(terms)) + } + if len(terms[0].MatchExpressions) != 6 { + t.Fatalf("expected 6 match expressions, got %d", len(terms[0].MatchExpressions)) + } + if got := terms[0].MatchExpressions[0].Values; len(got) != 3 { + t.Fatalf("expected 3 values for In operator, got %d", len(got)) + } + if len(terms[0].MatchFields) != 1 { + t.Fatalf("expected 1 match field, got %d", len(terms[0].MatchFields)) + } +} + +func TestToAffinityNilOrEmpty(t *testing.T) { + translator := NewNodeSelectorTranslator() + if affinity := translator.ToAffinity(nil); affinity != nil { + t.Fatal("expected nil affinity for nil selector") + } + if affinity := translator.ToAffinity(&corev1.NodeSelector{}); affinity != nil { + t.Fatal("expected nil affinity for empty terms") + } +} diff --git a/internal/prediction/phase.go b/internal/prediction/phase.go new file mode 100644 index 0000000..2871599 --- /dev/null +++ b/internal/prediction/phase.go @@ -0,0 +1,390 @@ +package prediction + +import ( + "context" + "fmt" + "time" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" + logf "sigs.k8s.io/controller-runtime/pkg/log" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const requeueAfterJobPoll = 10 * time.Second + +type PhaseRouter interface { + Handle(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) +} + +type phaseRouter struct { + client client.Client + scheme *runtime.Scheme + recorder record.EventRecorder + pvc PVCBuilder + jobs JobBuilder + input InputEncoder + retry RetryPolicy + timeout TimeoutChecker + cleaner ResourceCleaner + conditions ConditionsManager + status StatusWriter +} + +func NewPhaseRouter( + c client.Client, + scheme *runtime.Scheme, + recorder record.EventRecorder, + pvc PVCBuilder, + jobs JobBuilder, + input InputEncoder, + retry RetryPolicy, + timeout TimeoutChecker, + cleaner ResourceCleaner, + conditions ConditionsManager, + status StatusWriter, +) PhaseRouter { + return &phaseRouter{ + client: c, + scheme: scheme, + recorder: recorder, + pvc: pvc, + jobs: jobs, + input: input, + retry: retry, + timeout: timeout, + cleaner: cleaner, + conditions: conditions, + status: status, + } +} + +func (r *phaseRouter) Handle(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + log := logf.FromContext(ctx) + switch prediction.Status.Phase { + case datav1.ProteinConformationPredictionStatusPhaseNotStarted: + return r.handleNotStarted(ctx, prediction) + case datav1.ProteinConformationPredictionStatusPhaseAligning: + return r.handleAligning(ctx, prediction) + case datav1.ProteinConformationPredictionStatusPhasePredicting: + return r.handlePredicting(ctx, prediction) + case datav1.ProteinConformationPredictionStatusPhaseUploadingArtifacts: + return r.handleUploadingArtifacts(ctx, prediction) + case datav1.ProteinConformationPredictionStatusPhaseCompleted, datav1.ProteinConformationPredictionStatusPhaseFailed: + return ctrl.Result{}, nil + } + log.Info("Unknown phase", "Phase", prediction.Status.Phase) + return ctrl.Result{}, nil +} + +func (r *phaseRouter) handleNotStarted(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + log := logf.FromContext(ctx) + + database := &datav1.ProteinDatabase{} + if err := r.client.Get(ctx, types.NamespacedName{Name: prediction.Spec.Database, Namespace: prediction.Namespace}, database); err != nil { + if errors.IsNotFound(err) { + log.Info("Waiting for ProteinDatabase to be created", "Database", prediction.Spec.Database) + r.recorder.Event(prediction, corev1.EventTypeNormal, "DatabaseNotFound", fmt.Sprintf("Waiting for ProteinDatabase %s to be created", prediction.Spec.Database)) + return ctrl.Result{RequeueAfter: requeueAfterJobPoll}, nil + } + log.Error(err, "Failed to get ProteinDatabase") + r.recorder.Event(prediction, corev1.EventTypeWarning, "DatabaseError", fmt.Sprintf("Failed to get ProteinDatabase: %v", err)) + return ctrl.Result{RequeueAfter: requeueAfterJobPoll}, nil + } + + pvcName := shared.PredictionDataPVCName(prediction.Name) + if result, err := r.ensurePVC(ctx, prediction, pvcName); err != nil || !result.IsZero() { + return result, err + } + + jobName := shared.SearchJobName(prediction.Name) + job := &batchv1.Job{} + err := r.client.Get(ctx, types.NamespacedName{Name: jobName, Namespace: prediction.Namespace}, job) + if err != nil && !errors.IsNotFound(err) { + log.Error(err, "Failed to get search job") + r.recorder.Event(prediction, corev1.EventTypeWarning, "JobError", fmt.Sprintf("Failed to get search job: %v", err)) + return ctrl.Result{}, err + } + if errors.IsNotFound(err) { + encodedInput, encodeErr := r.input.Encode(prediction, false) + if encodeErr != nil { + log.Error(encodeErr, "Failed to prepare FoldInput") + r.recorder.Event(prediction, corev1.EventTypeWarning, "InputError", fmt.Sprintf("Failed to prepare FoldInput: %v", encodeErr)) + return ctrl.Result{}, encodeErr + } + newJob := r.jobs.BuildSearch(prediction, jobName, pvcName, encodedInput) + if err := r.createOwnedJob(ctx, prediction, newJob); err != nil { + log.Error(err, "Failed to create search job") + r.recorder.Event(prediction, corev1.EventTypeWarning, "JobCreationError", fmt.Sprintf("Failed to create search job: %v", err)) + return ctrl.Result{}, err + } + r.recorder.Event(prediction, corev1.EventTypeNormal, "JobCreated", fmt.Sprintf("Created search job %s", jobName)) + } + + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseAligning + r.conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeReady, + Status: metav1.ConditionFalse, + Reason: ReasonPhaseProgressed, + Message: "Aligning phase started", + }) + }); err != nil { + log.Error(err, "Failed to update ProteinConformationPrediction status") + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) ensurePVC(ctx context.Context, prediction *datav1.ProteinConformationPrediction, pvcName string) (ctrl.Result, error) { + log := logf.FromContext(ctx) + pvc := &corev1.PersistentVolumeClaim{} + err := r.client.Get(ctx, types.NamespacedName{Name: pvcName, Namespace: prediction.Namespace}, pvc) + if err == nil { + return ctrl.Result{}, nil + } + if !errors.IsNotFound(err) { + log.Error(err, "Failed to get PVC") + r.recorder.Event(prediction, corev1.EventTypeWarning, "PVCError", fmt.Sprintf("Failed to get PVC: %v", err)) + return ctrl.Result{}, err + } + + newPVC := r.pvc.Build(prediction, pvcName) + if err := controllerutil.SetControllerReference(prediction, newPVC, r.scheme); err != nil { + log.Error(err, "Failed to set controller reference for PVC") + return ctrl.Result{}, err + } + if err := r.client.Create(ctx, newPVC); err != nil && !errors.IsAlreadyExists(err) { + log.Error(err, "Failed to create PVC") + r.recorder.Event(prediction, corev1.EventTypeWarning, "PVCCreationError", fmt.Sprintf("Failed to create PVC: %v", err)) + return ctrl.Result{}, err + } + r.recorder.Event(prediction, corev1.EventTypeNormal, "PVCCreated", fmt.Sprintf("Created PVC %s", pvcName)) + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) handleAligning(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + return r.observeJob(ctx, prediction, jobObservation{ + phase: PhaseSearch, + jobName: shared.SearchJobName(prediction.Name), + nextPhase: datav1.ProteinConformationPredictionStatusPhasePredicting, + condType: ConditionTypeSearchSucceeded, + condMessage: "Search job completed", + }) +} + +func (r *phaseRouter) handlePredicting(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + jobName := shared.PredictJobName(prediction.Name) + job := &batchv1.Job{} + err := r.client.Get(ctx, types.NamespacedName{Name: jobName, Namespace: prediction.Namespace}, job) + if errors.IsNotFound(err) { + return r.createPredictionJob(ctx, prediction, jobName) + } + if err != nil { + return ctrl.Result{}, err + } + return r.observeExistingJob(ctx, prediction, job, jobObservation{ + phase: PhasePredict, + jobName: jobName, + nextPhase: datav1.ProteinConformationPredictionStatusPhaseUploadingArtifacts, + condType: ConditionTypePredictSucceeded, + condMessage: "Prediction job completed", + }) +} + +func (r *phaseRouter) createPredictionJob(ctx context.Context, prediction *datav1.ProteinConformationPrediction, jobName string) (ctrl.Result, error) { + log := logf.FromContext(ctx) + encodedInput, err := r.input.Encode(prediction, true) + if err != nil { + log.Error(err, "Failed to prepare FoldInput") + return ctrl.Result{}, err + } + pvcName := shared.PredictionDataPVCName(prediction.Name) + job := r.jobs.BuildPredict(prediction, jobName, pvcName, encodedInput) + if err := r.createOwnedJob(ctx, prediction, job); err != nil { + log.Error(err, "Failed to create prediction job") + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) createOwnedJob(ctx context.Context, prediction *datav1.ProteinConformationPrediction, job *batchv1.Job) error { + if err := controllerutil.SetControllerReference(prediction, job, r.scheme); err != nil { + return err + } + if err := r.client.Create(ctx, job); err != nil && !errors.IsAlreadyExists(err) { + return err + } + return nil +} + +func (r *phaseRouter) handleUploadingArtifacts(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + jobName := shared.UploadJobName(prediction.Name) + job := &batchv1.Job{} + err := r.client.Get(ctx, types.NamespacedName{Name: jobName, Namespace: prediction.Namespace}, job) + if errors.IsNotFound(err) { + return r.createUploadJob(ctx, prediction, jobName) + } + if err != nil { + return ctrl.Result{}, err + } + + if job.Status.Succeeded > 0 { + return r.completeUpload(ctx, prediction) + } + if r.timeout.IsTimedOut(job, DefaultJobTimeout) { + return r.markFailedDueToTimeout(ctx, prediction, jobName) + } + if job.Status.Failed > 0 { + return r.retryOrFail(ctx, prediction, job, PhaseUpload, jobName, "Upload job failed") + } + return ctrl.Result{RequeueAfter: requeueAfterJobPoll}, nil +} + +func (r *phaseRouter) createUploadJob(ctx context.Context, prediction *datav1.ProteinConformationPrediction, jobName string) (ctrl.Result, error) { + log := logf.FromContext(ctx) + pvcName := shared.PredictionDataPVCName(prediction.Name) + job := r.jobs.BuildUpload(prediction, jobName, pvcName) + if err := r.createOwnedJob(ctx, prediction, job); err != nil { + log.Error(err, "Failed to create upload job") + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) completeUpload(ctx context.Context, prediction *datav1.ProteinConformationPrediction) (ctrl.Result, error) { + log := logf.FromContext(ctx) + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseCompleted + r.conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeUploadSucceeded, + Status: metav1.ConditionTrue, + Reason: ReasonCompleted, + Message: "Upload job completed", + }) + r.conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeReady, + Status: metav1.ConditionTrue, + Reason: ReasonCompleted, + Message: "Prediction artifacts uploaded", + }) + }); err != nil { + return ctrl.Result{}, err + } + + if err := r.cleaner.DeleteCompletedJobs(ctx, prediction.Name, prediction.Namespace); err != nil { + log.Error(err, "Failed to cleanup completed jobs") + return ctrl.Result{}, err + } + if err := r.cleaner.DeletePVCForPrediction(ctx, prediction.Name, prediction.Namespace); err != nil { + log.Error(err, "Failed to delete PVC") + return ctrl.Result{}, err + } + return ctrl.Result{}, nil +} + +type jobObservation struct { + phase Phase + jobName string + nextPhase datav1.ProteinConformationPredictionStatusPhase + condType string + condMessage string +} + +func (r *phaseRouter) observeJob(ctx context.Context, prediction *datav1.ProteinConformationPrediction, observation jobObservation) (ctrl.Result, error) { + job := &batchv1.Job{} + err := r.client.Get(ctx, types.NamespacedName{Name: observation.jobName, Namespace: prediction.Namespace}, job) + if errors.IsNotFound(err) { + return ctrl.Result{RequeueAfter: requeueAfterJobPoll}, nil + } + if err != nil { + return ctrl.Result{}, err + } + return r.observeExistingJob(ctx, prediction, job, observation) +} + +func (r *phaseRouter) observeExistingJob(ctx context.Context, prediction *datav1.ProteinConformationPrediction, job *batchv1.Job, observation jobObservation) (ctrl.Result, error) { + if job.Status.Succeeded > 0 { + return r.advancePhase(ctx, prediction, observation) + } + if r.timeout.IsTimedOut(job, DefaultJobTimeout) { + return r.markFailedDueToTimeout(ctx, prediction, observation.jobName) + } + if job.Status.Failed > 0 { + return r.retryOrFail(ctx, prediction, job, observation.phase, observation.jobName, fmt.Sprintf("%s job failed", observation.phase)) + } + return ctrl.Result{RequeueAfter: requeueAfterJobPoll}, nil +} + +func (r *phaseRouter) advancePhase(ctx context.Context, prediction *datav1.ProteinConformationPrediction, observation jobObservation) (ctrl.Result, error) { + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = observation.nextPhase + r.conditions.Set(&p.Status, metav1.Condition{ + Type: observation.condType, + Status: metav1.ConditionTrue, + Reason: ReasonPhaseProgressed, + Message: observation.condMessage, + }) + r.conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeReady, + Status: metav1.ConditionFalse, + Reason: ReasonPhaseProgressed, + Message: fmt.Sprintf("Phase transitioned to %s", observation.nextPhase), + }) + }); err != nil { + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) retryOrFail(ctx context.Context, prediction *datav1.ProteinConformationPrediction, job *batchv1.Job, phase Phase, jobName, message string) (ctrl.Result, error) { + if r.retry.AtLimit(&prediction.Status, phase) { + return r.markFailed(ctx, prediction, fmt.Sprintf("%s (job %s) after max retries", message, jobName), ReasonRetriesExhausted) + } + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + r.retry.Increment(&p.Status, phase) + }); err != nil { + return ctrl.Result{}, err + } + if err := r.client.Delete(ctx, job); err != nil && !errors.IsNotFound(err) { + return ctrl.Result{}, err + } + return ctrl.Result{Requeue: true}, nil +} + +func (r *phaseRouter) markFailedDueToTimeout(ctx context.Context, prediction *datav1.ProteinConformationPrediction, jobName string) (ctrl.Result, error) { + if err := r.cleaner.DeleteJobsForPrediction(ctx, prediction.Name, prediction.Namespace); err != nil { + return ctrl.Result{}, err + } + if err := r.cleaner.DeletePVCForPrediction(ctx, prediction.Name, prediction.Namespace); err != nil { + return ctrl.Result{}, err + } + return r.markFailed(ctx, prediction, fmt.Sprintf("Job %s timed out", jobName), ReasonJobTimeout) +} + +func (r *phaseRouter) markFailed(ctx context.Context, prediction *datav1.ProteinConformationPrediction, message, reason string) (ctrl.Result, error) { + if err := r.status.Update(ctx, prediction, func(p *datav1.ProteinConformationPrediction) { + p.Status.Phase = datav1.ProteinConformationPredictionStatusPhaseFailed + p.Status.Error = message + r.conditions.Set(&p.Status, metav1.Condition{ + Type: ConditionTypeReady, + Status: metav1.ConditionFalse, + Reason: reason, + Message: message, + }) + }); err != nil { + return ctrl.Result{}, err + } + r.recorder.Event(prediction, corev1.EventTypeWarning, reason, message) + return ctrl.Result{}, nil +} diff --git a/internal/prediction/pvc.go b/internal/prediction/pvc.go new file mode 100644 index 0000000..28f89d2 --- /dev/null +++ b/internal/prediction/pvc.go @@ -0,0 +1,62 @@ +package prediction + +import ( + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const ( + DefaultStorageClass = "fsx-sc" + predictionPVCSize = "10Gi" +) + +type PVCBuilder interface { + Build(prediction *datav1.ProteinConformationPrediction, pvcName string) *corev1.PersistentVolumeClaim +} + +type pvcBuilder struct{} + +func NewPVCBuilder() PVCBuilder { + return &pvcBuilder{} +} + +func (b *pvcBuilder) Build(prediction *datav1.ProteinConformationPrediction, pvcName string) *corev1.PersistentVolumeClaim { + storageClass := resolveStorageClass(prediction) + + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: pvcName, + Namespace: prediction.Namespace, + Labels: shared.PredictionPVCLabels(prediction.Name), + }, + Spec: corev1.PersistentVolumeClaimSpec{ + AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce}, + Resources: corev1.VolumeResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceStorage: resource.MustParse(predictionPVCSize), + }, + }, + StorageClassName: &storageClass, + }, + } + + if prediction.Spec.Model.Volume.Selector != nil { + pvc.Spec.Selector = prediction.Spec.Model.Volume.Selector + } + + return pvc +} + +func resolveStorageClass(prediction *datav1.ProteinConformationPrediction) string { + if prediction.Spec.Model.Volume.StorageClassName != nil && *prediction.Spec.Model.Volume.StorageClassName != "" { + return *prediction.Spec.Model.Volume.StorageClassName + } + if prediction.Spec.StorageClass != "" { + return prediction.Spec.StorageClass + } + return DefaultStorageClass +} diff --git a/internal/prediction/resources.go b/internal/prediction/resources.go new file mode 100644 index 0000000..d59bc35 --- /dev/null +++ b/internal/prediction/resources.go @@ -0,0 +1,55 @@ +package prediction + +import ( + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + + datav1 "github.com/kubefold/operator/api/v1" +) + +type phaseResources struct { + cpu string + memory string +} + +const ( + phaseSearchKey = "search" + phasePredictKey = "predict" + phaseUploadKey = "upload" +) + +var profileResourceMap = map[datav1.ProteinConformationPredictionProfile]map[string]phaseResources{ + datav1.ProteinConformationPredictionProfileSmall: { + phaseSearchKey: {cpu: "1500m", memory: "6Gi"}, + phasePredictKey: {cpu: "1", memory: "4Gi"}, + phaseUploadKey: {cpu: "100m", memory: "256Mi"}, + }, + datav1.ProteinConformationPredictionProfileMedium: { + phaseSearchKey: {cpu: "3", memory: "12Gi"}, + phasePredictKey: {cpu: "1", memory: "8Gi"}, + phaseUploadKey: {cpu: "100m", memory: "256Mi"}, + }, + datav1.ProteinConformationPredictionProfileLarge: { + phaseSearchKey: {cpu: "6", memory: "48Gi"}, + phasePredictKey: {cpu: "2", memory: "16Gi"}, + phaseUploadKey: {cpu: "100m", memory: "256Mi"}, + }, +} + +func resourcesForPhase(profile datav1.ProteinConformationPredictionProfile, phaseKey string) corev1.ResourceRequirements { + chosenProfile := profile + if chosenProfile == "" { + chosenProfile = datav1.ProteinConformationPredictionProfileMedium + } + phases, ok := profileResourceMap[chosenProfile] + if !ok { + phases = profileResourceMap[datav1.ProteinConformationPredictionProfileMedium] + } + requested := phases[phaseKey] + return corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse(requested.cpu), + corev1.ResourceMemory: resource.MustParse(requested.memory), + }, + } +} diff --git a/internal/prediction/retry.go b/internal/prediction/retry.go new file mode 100644 index 0000000..304242a --- /dev/null +++ b/internal/prediction/retry.go @@ -0,0 +1,56 @@ +package prediction + +import ( + datav1 "github.com/kubefold/operator/api/v1" +) + +const MaxRetries = int32(3) + +type Phase string + +const ( + PhaseSearch Phase = "search" + PhasePredict Phase = "predict" + PhaseUpload Phase = "upload" +) + +type RetryPolicy interface { + Counter(status *datav1.ProteinConformationPredictionStatus, phase Phase) int32 + Increment(status *datav1.ProteinConformationPredictionStatus, phase Phase) + AtLimit(status *datav1.ProteinConformationPredictionStatus, phase Phase) bool +} + +type retryPolicy struct { + max int32 +} + +func NewRetryPolicy(max int32) RetryPolicy { + return &retryPolicy{max: max} +} + +func (p *retryPolicy) Counter(status *datav1.ProteinConformationPredictionStatus, phase Phase) int32 { + switch phase { + case PhaseSearch: + return status.SearchRetryCount + case PhasePredict: + return status.PredictRetryCount + case PhaseUpload: + return status.UploadRetryCount + } + return 0 +} + +func (p *retryPolicy) Increment(status *datav1.ProteinConformationPredictionStatus, phase Phase) { + switch phase { + case PhaseSearch: + status.SearchRetryCount++ + case PhasePredict: + status.PredictRetryCount++ + case PhaseUpload: + status.UploadRetryCount++ + } +} + +func (p *retryPolicy) AtLimit(status *datav1.ProteinConformationPredictionStatus, phase Phase) bool { + return p.Counter(status, phase) >= p.max +} diff --git a/internal/prediction/retry_test.go b/internal/prediction/retry_test.go new file mode 100644 index 0000000..075a194 --- /dev/null +++ b/internal/prediction/retry_test.go @@ -0,0 +1,61 @@ +package prediction + +import ( + "testing" + + datav1 "github.com/kubefold/operator/api/v1" +) + +func TestRetryPolicyPerPhaseIndependence(t *testing.T) { + policy := NewRetryPolicy(3) + status := &datav1.ProteinConformationPredictionStatus{} + + policy.Increment(status, PhaseSearch) + policy.Increment(status, PhaseSearch) + if status.SearchRetryCount != 2 { + t.Fatalf("SearchRetryCount = %d, want 2", status.SearchRetryCount) + } + if status.PredictRetryCount != 0 || status.UploadRetryCount != 0 { + t.Fatal("Predict/Upload counters should not be touched by Search increment") + } + + policy.Increment(status, PhasePredict) + if status.PredictRetryCount != 1 { + t.Fatalf("PredictRetryCount = %d, want 1", status.PredictRetryCount) + } +} + +func TestRetryPolicyAtLimit(t *testing.T) { + policy := NewRetryPolicy(2) + status := &datav1.ProteinConformationPredictionStatus{} + + if policy.AtLimit(status, PhaseUpload) { + t.Fatal("should not be at limit initially") + } + policy.Increment(status, PhaseUpload) + if policy.AtLimit(status, PhaseUpload) { + t.Fatal("should not be at limit after 1 retry with max=2") + } + policy.Increment(status, PhaseUpload) + if !policy.AtLimit(status, PhaseUpload) { + t.Fatal("should be at limit after 2 retries with max=2") + } +} + +func TestRetryPolicyCounter(t *testing.T) { + policy := NewRetryPolicy(5) + status := &datav1.ProteinConformationPredictionStatus{ + SearchRetryCount: 3, + PredictRetryCount: 1, + UploadRetryCount: 0, + } + if got := policy.Counter(status, PhaseSearch); got != 3 { + t.Fatalf("Counter(Search) = %d, want 3", got) + } + if got := policy.Counter(status, PhasePredict); got != 1 { + t.Fatalf("Counter(Predict) = %d, want 1", got) + } + if got := policy.Counter(status, PhaseUpload); got != 0 { + t.Fatalf("Counter(Upload) = %d, want 0", got) + } +} diff --git a/internal/prediction/status.go b/internal/prediction/status.go new file mode 100644 index 0000000..d23013f --- /dev/null +++ b/internal/prediction/status.go @@ -0,0 +1,39 @@ +package prediction + +import ( + "context" + + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +type StatusWriter interface { + Update(ctx context.Context, prediction *datav1.ProteinConformationPrediction, mutate func(*datav1.ProteinConformationPrediction)) error +} + +type statusWriter struct { + client client.Client +} + +func NewStatusWriter(c client.Client) StatusWriter { + return &statusWriter{client: c} +} + +func (w *statusWriter) Update(ctx context.Context, prediction *datav1.ProteinConformationPrediction, mutate func(*datav1.ProteinConformationPrediction)) error { + return shared.RetryOnConflict(ctx, func() error { + latest := &datav1.ProteinConformationPrediction{} + if err := w.client.Get(ctx, types.NamespacedName{Name: prediction.Name, Namespace: prediction.Namespace}, latest); err != nil { + return err + } + mutate(latest) + if err := w.client.Status().Update(ctx, latest); err != nil { + return err + } + prediction.Status = latest.Status + prediction.ResourceVersion = latest.ResourceVersion + return nil + }) +} diff --git a/internal/prediction/timeout.go b/internal/prediction/timeout.go new file mode 100644 index 0000000..45fc652 --- /dev/null +++ b/internal/prediction/timeout.go @@ -0,0 +1,32 @@ +package prediction + +import ( + "time" + + batchv1 "k8s.io/api/batch/v1" +) + +const DefaultJobTimeout = 24 * time.Hour + +type TimeoutChecker interface { + IsTimedOut(job *batchv1.Job, defaultTimeout time.Duration) bool +} + +type timeoutChecker struct { + now func() time.Time +} + +func NewTimeoutChecker() TimeoutChecker { + return &timeoutChecker{now: time.Now} +} + +func (c *timeoutChecker) IsTimedOut(job *batchv1.Job, defaultTimeout time.Duration) bool { + if job.Status.StartTime == nil { + return false + } + timeout := defaultTimeout + if job.Spec.ActiveDeadlineSeconds != nil { + timeout = time.Duration(*job.Spec.ActiveDeadlineSeconds) * time.Second + } + return c.now().Sub(job.Status.StartTime.Time) > timeout +} diff --git a/internal/prediction/timeout_test.go b/internal/prediction/timeout_test.go new file mode 100644 index 0000000..8c9fe1a --- /dev/null +++ b/internal/prediction/timeout_test.go @@ -0,0 +1,52 @@ +package prediction + +import ( + "testing" + "time" + + batchv1 "k8s.io/api/batch/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func TestTimeoutCheckerNoStartTime(t *testing.T) { + checker := &timeoutChecker{now: func() time.Time { return time.Unix(0, 0) }} + job := &batchv1.Job{} + if checker.IsTimedOut(job, time.Hour) { + t.Fatal("job without StartTime must not be timed out") + } +} + +func TestTimeoutCheckerExceededDefault(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + checker := &timeoutChecker{now: func() time.Time { return start.Add(2 * time.Hour) }} + job := &batchv1.Job{ + Status: batchv1.JobStatus{StartTime: &metav1.Time{Time: start}}, + } + if !checker.IsTimedOut(job, time.Hour) { + t.Fatal("expected timeout after 2h when default is 1h") + } +} + +func TestTimeoutCheckerWithinDefault(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + checker := &timeoutChecker{now: func() time.Time { return start.Add(30 * time.Minute) }} + job := &batchv1.Job{ + Status: batchv1.JobStatus{StartTime: &metav1.Time{Time: start}}, + } + if checker.IsTimedOut(job, time.Hour) { + t.Fatal("must not be timed out before threshold") + } +} + +func TestTimeoutCheckerUsesActiveDeadlineSeconds(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + checker := &timeoutChecker{now: func() time.Time { return start.Add(2 * time.Minute) }} + deadline := int64(60) + job := &batchv1.Job{ + Spec: batchv1.JobSpec{ActiveDeadlineSeconds: &deadline}, + Status: batchv1.JobStatus{StartTime: &metav1.Time{Time: start}}, + } + if !checker.IsTimedOut(job, 24*time.Hour) { + t.Fatal("ActiveDeadlineSeconds=60 must take precedence over default 24h") + } +} diff --git a/internal/prediction/validate.go b/internal/prediction/validate.go new file mode 100644 index 0000000..9feb34f --- /dev/null +++ b/internal/prediction/validate.go @@ -0,0 +1,52 @@ +package prediction + +import ( + "fmt" + + datav1 "github.com/kubefold/operator/api/v1" + "github.com/kubefold/operator/internal/shared" +) + +const sequencePrefixLength = 10 + +type SpecValidator interface { + Validate(prediction *datav1.ProteinConformationPrediction) error + SequencePrefix(sequence string) string +} + +type specValidator struct { + allowedWeightsHosts []string +} + +func NewSpecValidator(allowedWeightsHosts []string) SpecValidator { + return &specValidator{allowedWeightsHosts: allowedWeightsHosts} +} + +func (v *specValidator) Validate(prediction *datav1.ProteinConformationPrediction) error { + if prediction.Spec.Protein.Sequence == "" { + return fmt.Errorf("protein sequence cannot be empty") + } + if prediction.Spec.Database == "" { + return fmt.Errorf("database reference cannot be empty") + } + if prediction.Spec.Destination.S3.Bucket == "" { + return fmt.Errorf("destination S3 bucket cannot be empty") + } + if prediction.Spec.Destination.S3.Region == "" { + return fmt.Errorf("destination S3 region cannot be empty") + } + if err := shared.ValidateHTTPSURL( + prediction.Spec.Model.Weights.HTTP, + shared.WithAllowedHosts(v.allowedWeightsHosts...), + ); err != nil { + return fmt.Errorf("model weights URL is invalid: %w", err) + } + return nil +} + +func (v *specValidator) SequencePrefix(sequence string) string { + if len(sequence) <= sequencePrefixLength { + return sequence + } + return sequence[:sequencePrefixLength] + "..." +} diff --git a/internal/prediction/validate_test.go b/internal/prediction/validate_test.go new file mode 100644 index 0000000..bf7c285 --- /dev/null +++ b/internal/prediction/validate_test.go @@ -0,0 +1,96 @@ +package prediction + +import ( + "strings" + "testing" + + datav1 "github.com/kubefold/operator/api/v1" +) + +func TestSequencePrefix(t *testing.T) { + v := NewSpecValidator(nil) + + cases := []struct { + name string + sequence string + want string + }{ + {"empty", "", ""}, + {"short under threshold", "ABCD", "ABCD"}, + {"exactly threshold", "ABCDEFGHIJ", "ABCDEFGHIJ"}, + {"over threshold", "ABCDEFGHIJKLMNO", "ABCDEFGHIJ..."}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := v.SequencePrefix(tc.sequence) + if got != tc.want { + t.Fatalf("SequencePrefix(%q) = %q, want %q", tc.sequence, got, tc.want) + } + }) + } +} + +func TestValidateSpec(t *testing.T) { + makeValid := func() *datav1.ProteinConformationPrediction { + return &datav1.ProteinConformationPrediction{ + Spec: datav1.ProteinConformationPredictionSpec{ + Protein: datav1.ProteinConformationPredictionProtein{ + Sequence: "ABCDEFG", + ID: []string{"id-1"}, + }, + Database: "db-1", + Destination: datav1.ProteinConformationPredictionDestination{ + S3: datav1.ProteinConformationPredictionDestinationS3{ + Bucket: "bucket", + Region: "us-east-1", + }, + }, + Model: datav1.ProteinConformationPredictionModel{ + Weights: datav1.ProteinConformationPredictionModelWeights{ + HTTP: "https://example.com/weights.zst", + }, + }, + }, + } + } + + cases := []struct { + name string + mutate func(*datav1.ProteinConformationPrediction) + wantErr bool + errSubstr string + }{ + {"valid", func(p *datav1.ProteinConformationPrediction) {}, false, ""}, + {"empty sequence", func(p *datav1.ProteinConformationPrediction) { p.Spec.Protein.Sequence = "" }, true, "sequence"}, + {"empty database", func(p *datav1.ProteinConformationPrediction) { p.Spec.Database = "" }, true, "database"}, + {"empty bucket", func(p *datav1.ProteinConformationPrediction) { p.Spec.Destination.S3.Bucket = "" }, true, "bucket"}, + {"empty region", func(p *datav1.ProteinConformationPrediction) { p.Spec.Destination.S3.Region = "" }, true, "region"}, + {"http scheme", func(p *datav1.ProteinConformationPrediction) { p.Spec.Model.Weights.HTTP = "http://example.com/x" }, true, "URL"}, + {"shell injection in url", func(p *datav1.ProteinConformationPrediction) { + p.Spec.Model.Weights.HTTP = "https://example.com/x;rm -rf /" + }, true, "URL"}, + {"empty weights url", func(p *datav1.ProteinConformationPrediction) { p.Spec.Model.Weights.HTTP = "" }, true, "URL"}, + } + + v := NewSpecValidator(nil) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + prediction := makeValid() + tc.mutate(prediction) + err := v.Validate(prediction) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.errSubstr) { + t.Fatalf("error %q does not contain %q", err.Error(), tc.errSubstr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/internal/shared/conditions.go b/internal/shared/conditions.go new file mode 100644 index 0000000..c4d2678 --- /dev/null +++ b/internal/shared/conditions.go @@ -0,0 +1,17 @@ +package shared + +import ( + apimeta "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +func SetCondition(conditions *[]metav1.Condition, condition metav1.Condition) { + if condition.LastTransitionTime.IsZero() { + condition.LastTransitionTime = metav1.Now() + } + apimeta.SetStatusCondition(conditions, condition) +} + +func FindCondition(conditions []metav1.Condition, conditionType string) *metav1.Condition { + return apimeta.FindStatusCondition(conditions, conditionType) +} diff --git a/internal/shared/delete.go b/internal/shared/delete.go new file mode 100644 index 0000000..1c6ca49 --- /dev/null +++ b/internal/shared/delete.go @@ -0,0 +1,13 @@ +package shared + +import ( + "context" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func DeleteInBackground(ctx context.Context, c client.Client, object client.Object) error { + propagation := metav1.DeletePropagationBackground + return client.IgnoreNotFound(c.Delete(ctx, object, &client.DeleteOptions{PropagationPolicy: &propagation})) +} diff --git a/internal/shared/labels.go b/internal/shared/labels.go new file mode 100644 index 0000000..95aaf3e --- /dev/null +++ b/internal/shared/labels.go @@ -0,0 +1,58 @@ +package shared + +import "maps" + +const ( + LabelApp = "app" + LabelDatabase = "data.kubefold.io/database" + LabelPrediction = "data.kubefold.io/prediction" + LabelDataset = "data.kubefold.io/dataset" + LabelStep = "data.kubefold.io/step" + LabelAppName = "app.kubernetes.io/name" + LabelAppInstance = "app.kubernetes.io/instance" + LabelAppManagedBy = "app.kubernetes.io/managed-by" + ManagedByOperator = "kubefold-operator" +) + +func DatabaseLabels(databaseName, appName string) map[string]string { + return map[string]string{ + LabelDatabase: databaseName, + LabelAppName: appName, + LabelAppInstance: databaseName, + LabelAppManagedBy: ManagedByOperator, + } +} + +func DownloaderJobLabels(databaseName, datasetShortName string) map[string]string { + labels := DatabaseLabels(databaseName, "proteindatabase-downloader") + labels[LabelDataset] = datasetShortName + return labels +} + +func PredictionLabels(predictionName, appName, step string) map[string]string { + return map[string]string{ + LabelApp: predictionName, + LabelPrediction: predictionName, + LabelStep: step, + LabelAppName: appName, + LabelAppInstance: predictionName, + LabelAppManagedBy: ManagedByOperator, + } +} + +func PredictionPVCLabels(predictionName string) map[string]string { + return map[string]string{ + LabelApp: predictionName, + LabelPrediction: predictionName, + LabelAppName: "proteinconformationprediction-data", + LabelAppInstance: predictionName, + LabelAppManagedBy: ManagedByOperator, + } +} + +func MergeLabels(base, overrides map[string]string) map[string]string { + merged := make(map[string]string, len(base)+len(overrides)) + maps.Copy(merged, base) + maps.Copy(merged, overrides) + return merged +} diff --git a/internal/shared/names.go b/internal/shared/names.go new file mode 100644 index 0000000..4ac9f3e --- /dev/null +++ b/internal/shared/names.go @@ -0,0 +1,35 @@ +package shared + +import "fmt" + +const ( + dataPVCSuffix = "-data" + searchJobSuffix = "-search" + predictJobSuffix = "-predict" + uploadJobSuffix = "-upload" + downloaderSuffix = "-downloader" +) + +func PredictionDataPVCName(predictionName string) string { + return predictionName + dataPVCSuffix +} + +func DatabasePVCName(databaseName string) string { + return databaseName + dataPVCSuffix +} + +func SearchJobName(predictionName string) string { + return predictionName + searchJobSuffix +} + +func PredictJobName(predictionName string) string { + return predictionName + predictJobSuffix +} + +func UploadJobName(predictionName string) string { + return predictionName + uploadJobSuffix +} + +func DownloaderJobName(databaseName, datasetShortName string) string { + return fmt.Sprintf("%s-%s%s", databaseName, datasetShortName, downloaderSuffix) +} diff --git a/internal/shared/retry.go b/internal/shared/retry.go new file mode 100644 index 0000000..2f7667d --- /dev/null +++ b/internal/shared/retry.go @@ -0,0 +1,16 @@ +package shared + +import ( + "context" + + "k8s.io/client-go/util/retry" +) + +func RetryOnConflict(ctx context.Context, fn func() error) error { + return retry.RetryOnConflict(retry.DefaultBackoff, func() error { + if err := ctx.Err(); err != nil { + return err + } + return fn() + }) +} diff --git a/internal/shared/url.go b/internal/shared/url.go new file mode 100644 index 0000000..0868757 --- /dev/null +++ b/internal/shared/url.go @@ -0,0 +1,61 @@ +package shared + +import ( + "fmt" + "net/url" + "slices" + "strings" +) + +const forbiddenURLCharacters = "`$();|&<>\n\r\t\"'\\ " + +type URLPolicy struct { + allowedHosts []string +} + +type URLOption func(*URLPolicy) + +func WithAllowedHosts(hosts ...string) URLOption { + return func(p *URLPolicy) { + p.allowedHosts = hosts + } +} + +func ValidateHTTPSURL(raw string, options ...URLOption) error { + if raw == "" { + return fmt.Errorf("URL cannot be empty") + } + if strings.ContainsAny(raw, forbiddenURLCharacters) { + return fmt.Errorf("URL contains forbidden characters") + } + + parsed, err := url.Parse(raw) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if parsed.Scheme != "https" { + return fmt.Errorf("URL scheme must be https, got %q", parsed.Scheme) + } + if parsed.Host == "" { + return fmt.Errorf("URL host cannot be empty") + } + if parsed.User != nil { + return fmt.Errorf("URL must not contain userinfo") + } + if parsed.Fragment != "" { + return fmt.Errorf("URL must not contain fragment") + } + + policy := URLPolicy{} + for _, apply := range options { + apply(&policy) + } + if len(policy.allowedHosts) > 0 { + host := parsed.Hostname() + if !slices.Contains(policy.allowedHosts, host) { + return fmt.Errorf("URL host %q is not in the allowlist", host) + } + } + + return nil +} diff --git a/internal/shared/url_test.go b/internal/shared/url_test.go new file mode 100644 index 0000000..3bad450 --- /dev/null +++ b/internal/shared/url_test.go @@ -0,0 +1,41 @@ +package shared + +import "testing" + +func TestValidateHTTPSURL(t *testing.T) { + cases := []struct { + name string + url string + options []URLOption + wantPass bool + }{ + {"valid https", "https://example.com/weights.zst", nil, true}, + {"valid https with port", "https://example.com:8443/weights.zst", nil, true}, + {"empty", "", nil, false}, + {"http scheme", "http://example.com/weights", nil, false}, + {"ftp scheme", "ftp://example.com/weights", nil, false}, + {"shell injection semicolon", "https://example.com/weights;rm -rf /", nil, false}, + {"shell injection backtick", "https://example.com/`whoami`", nil, false}, + {"shell injection dollar", "https://example.com/$(whoami)", nil, false}, + {"shell injection pipe", "https://example.com/x|cat", nil, false}, + {"shell injection ampersand", "https://example.com/x&&id", nil, false}, + {"newline", "https://example.com/x\n", nil, false}, + {"userinfo", "https://user:pass@example.com/weights", nil, false}, + {"fragment", "https://example.com/weights#frag", nil, false}, + {"allowlist match", "https://trusted.example.com/file", []URLOption{WithAllowedHosts("trusted.example.com")}, true}, + {"allowlist miss", "https://evil.example.com/file", []URLOption{WithAllowedHosts("trusted.example.com")}, false}, + {"allowlist multiple match", "https://b.com/x", []URLOption{WithAllowedHosts("a.com", "b.com")}, true}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateHTTPSURL(tc.url, tc.options...) + if tc.wantPass && err != nil { + t.Fatalf("expected pass, got error: %v", err) + } + if !tc.wantPass && err == nil { + t.Fatalf("expected error for url %q", tc.url) + } + }) + } +} diff --git a/internal/util/format.go b/internal/util/format.go index 8053778..92e03ff 100644 --- a/internal/util/format.go +++ b/internal/util/format.go @@ -30,5 +30,8 @@ func FormatSpeed(speed int64) string { } func FormatPercentage(a, b int64) string { + if b <= 0 { + return "0.0%" + } return fmt.Sprintf("%.1f%%", float64(a)/float64(b)*100) } diff --git a/internal/util/format_test.go b/internal/util/format_test.go new file mode 100644 index 0000000..359cde4 --- /dev/null +++ b/internal/util/format_test.go @@ -0,0 +1,26 @@ +package util + +import "testing" + +const zeroPercent = "0.0%" + +func TestFormatPercentageZeroDenominator(t *testing.T) { + if got := FormatPercentage(100, 0); got != zeroPercent { + t.Fatalf("FormatPercentage(100, 0) = %q, want %s", got, zeroPercent) + } + if got := FormatPercentage(100, -1); got != zeroPercent { + t.Fatalf("FormatPercentage(100, -1) = %q, want %s", got, zeroPercent) + } +} + +func TestFormatPercentage(t *testing.T) { + if got := FormatPercentage(50, 100); got != "50.0%" { + t.Fatalf("FormatPercentage(50, 100) = %q, want 50.0%%", got) + } + if got := FormatPercentage(0, 100); got != zeroPercent { + t.Fatalf("FormatPercentage(0, 100) = %q, want %s", got, zeroPercent) + } + if got := FormatPercentage(100, 100); got != "100.0%" { + t.Fatalf("FormatPercentage(100, 100) = %q, want 100.0%%", got) + } +} diff --git a/up.sh b/up.sh index 0985c2e..4f37581 100755 --- a/up.sh +++ b/up.sh @@ -1,5 +1,6 @@ #!/bin/bash +export AWS_PROFILE=solidchat export AWS_PAGER="" ACCOUNT_ID=$(aws sts get-caller-identity --query Account --output text) REGION="eu-central-1"