diff --git a/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcher.java b/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcher.java index 70ac6866..b3f4a284 100644 --- a/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcher.java +++ b/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcher.java @@ -73,6 +73,7 @@ public String handle(Resource resource, DispatchInfo dispatchInfo) { DAGSettings dagSettings = DAGSettings.builder() .ignoreExist(false) .dagMaxDepth(bizDConfs.getFlowDAGMaxDepth()).build(); + olympicene.submit(executionId, dag, data, dagSettings, notifyInfo); dagResourceStatistic.updateFlowTypeResourceStatus(parentDAGExecutionId, parentTaskName, resource.getResourceName(), dag); ProfileActions.recordTinyDAGSubmit(executionId); diff --git a/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcher.java b/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcher.java index ddddb630..c303a29f 100644 --- a/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcher.java +++ b/rill-flow-service/src/main/java/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcher.java @@ -70,6 +70,7 @@ public String handle(Resource resource, DispatchInfo dispatchInfo) { int maxInvokeTime = switcherManagerImpl.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") ? 2 : 1; HttpMethod method = Optional.ofNullable(requestType).map(String::toUpperCase).map(HttpMethod::resolve).orElse(HttpMethod.POST); HttpEntity requestEntity = buildHttpEntity(method, header, requestParams); + String ret = httpInvokeHelper.invokeRequest(executionId, taskInfoName, url, requestEntity, method, maxInvokeTime); dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskInfoName, resource.getResourceName(), ret); return ret; diff --git a/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcherTest.groovy b/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcherTest.groovy new file mode 100644 index 00000000..23c6363b --- /dev/null +++ b/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FlowProtocolDispatcherTest.groovy @@ -0,0 +1,179 @@ +/* + * Copyright 2021-2023 Weibo, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.weibo.rill.flow.service.dispatcher + +import com.alibaba.fastjson.JSON +import com.weibo.rill.flow.interfaces.model.resource.Resource +import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo +import com.weibo.rill.flow.interfaces.model.task.FunctionPattern +import com.weibo.rill.flow.interfaces.model.task.FunctionTask +import com.weibo.rill.flow.interfaces.model.task.TaskInfo +import com.weibo.rill.flow.olympicene.core.model.dag.DAG +import com.weibo.rill.flow.olympicene.traversal.Olympicene +import com.weibo.rill.flow.service.dconfs.BizDConfs +import com.weibo.rill.flow.service.service.DAGDescriptorService +import com.weibo.rill.flow.service.statistic.DAGResourceStatistic +import spock.lang.Specification + +class FlowProtocolDispatcherTest extends Specification { + FlowProtocolDispatcher dispatcher + DAGDescriptorService dagDescriptorService + BizDConfs bizDConfs + Olympicene olympicene + DAGResourceStatistic dagResourceStatistic + + def setup() { + dispatcher = new FlowProtocolDispatcher() + dagDescriptorService = Mock(DAGDescriptorService) + bizDConfs = Mock(BizDConfs) + olympicene = Mock(Olympicene) + dagResourceStatistic = Mock(DAGResourceStatistic) + + dispatcher.dagDescriptorService = dagDescriptorService + dispatcher.bizDConfs = bizDConfs + dispatcher.olympicene = olympicene + dispatcher.dagResourceStatistic = dagResourceStatistic + } + + def "test handle method with valid input"() { + given: + def resource = Mock(Resource) { + getSchemeValue() >> "test-scheme" + getResourceName() >> "test-resource" + } + def taskInfo = Mock(TaskInfo) { + getName() >> "test-task" + getTask() >> Mock(FunctionTask) { + getPattern() >> FunctionPattern.FLOW_ASYNC + } + } + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> "parent-execution-id" + getTaskInfo() >> taskInfo + getInput() >> ["uid": "123", "key": "value"] + } + def dag = Mock(DAG) + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + 1 * bizDConfs.getFlowDAGMaxDepth() >> 10 + 1 * dagDescriptorService.getDAG(123L, _, "test-scheme") >> dag + 1 * olympicene.submit(_, dag, _, _, _) + 1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag) + + and: + def jsonResult = JSON.parseObject(result) + jsonResult.containsKey("execution_id") + jsonResult.get("execution_id") != null + } + + def "test handle method with null input map"() { + given: + def resource = Mock(Resource) { + getSchemeValue() >> "test-scheme" + getResourceName() >> "test-resource" + } + def taskInfo = Mock(TaskInfo) { + getName() >> "test-task" + getTask() >> Mock(FunctionTask) { + getPattern() >> FunctionPattern.FLOW_ASYNC + } + } + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> "parent-execution-id" + getTaskInfo() >> taskInfo + getInput() >> null + } + def dag = Mock(DAG) + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + 1 * bizDConfs.getFlowDAGMaxDepth() >> 10 + 1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag + 1 * olympicene.submit(_, dag, _, _, _) + 1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag) + + and: + def jsonResult = JSON.parseObject(result) + jsonResult.containsKey("execution_id") + jsonResult.get("execution_id") != null + } + + def "test handle method with invalid uid"() { + given: + def resource = Mock(Resource) { + getSchemeValue() >> "test-scheme" + getResourceName() >> "test-resource" + } + def taskInfo = Mock(TaskInfo) { + getName() >> "test-task" + getTask() >> Mock(FunctionTask) { + getPattern() >> FunctionPattern.FLOW_ASYNC + } + } + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> "parent-execution-id" + getTaskInfo() >> taskInfo + getInput() >> ["uid": null, "key": "value"] + } + def dag = Mock(DAG) + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + 1 * bizDConfs.getFlowDAGMaxDepth() >> 10 + 1 * dagDescriptorService.getDAG(0L, _, "test-scheme") >> dag + 1 * olympicene.submit(_, dag, _, _, _) + 1 * dagResourceStatistic.updateFlowTypeResourceStatus("parent-execution-id", "test-task", "test-resource", dag) + + and: + def jsonResult = JSON.parseObject(result) + jsonResult.containsKey("execution_id") + jsonResult.get("execution_id") != null + } + + def "test handle method with non-numeric uid"() { + given: + def resource = Mock(Resource) { + getSchemeValue() >> "test-scheme" + getResourceName() >> "test-resource" + } + def taskInfo = Mock(TaskInfo) { + getName() >> "test-task" + getTask() >> Mock(FunctionTask) { + getPattern() >> FunctionPattern.FLOW_ASYNC + } + } + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> "parent-execution-id" + getTaskInfo() >> taskInfo + getInput() >> ["uid": "not-a-number", "key": "value"] + } + def dag = Mock(DAG) + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + thrown(NumberFormatException) + } +} diff --git a/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcherTest.groovy b/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcherTest.groovy index e98fcff0..05216fc3 100644 --- a/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcherTest.groovy +++ b/rill-flow-service/src/test/groovy/com/weibo/rill/flow/service/dispatcher/FunctionProtocolDispatcherTest.groovy @@ -16,40 +16,169 @@ package com.weibo.rill.flow.service.dispatcher +import com.weibo.rill.flow.common.exception.TaskException import com.weibo.rill.flow.interfaces.model.http.HttpParameter +import com.weibo.rill.flow.interfaces.model.resource.Resource +import com.weibo.rill.flow.interfaces.model.strategy.DispatchInfo +import com.weibo.rill.flow.interfaces.model.task.FunctionTask +import com.weibo.rill.flow.interfaces.model.task.TaskInfo +import com.weibo.rill.flow.olympicene.core.switcher.SwitcherManager +import com.weibo.rill.flow.service.invoke.HttpInvokeHelper +import com.weibo.rill.flow.service.statistic.DAGResourceStatistic +import org.springframework.http.HttpEntity +import org.springframework.http.HttpHeaders import org.springframework.http.HttpMethod import org.springframework.http.MediaType import org.springframework.util.LinkedMultiValueMap -import org.springframework.util.MultiValueMap +import org.springframework.web.client.RestClientResponseException import spock.lang.Specification +import spock.lang.Subject class FunctionProtocolDispatcherTest extends Specification { - FunctionProtocolDispatcher dispatcher = new FunctionProtocolDispatcher(); + @Subject + FunctionProtocolDispatcher dispatcher - def "buildHttpEntity test"() { + HttpInvokeHelper httpInvokeHelper + DAGResourceStatistic dagResourceStatistic + SwitcherManager switcherManager + + def setup() { + httpInvokeHelper = Mock(HttpInvokeHelper) + dagResourceStatistic = Mock(DAGResourceStatistic) + switcherManager = Mock(SwitcherManager) + dispatcher = new FunctionProtocolDispatcher( + httpInvokeHelper: httpInvokeHelper, + dagResourceStatistic: dagResourceStatistic, + switcherManagerImpl: switcherManager + ) + } + + def "should handle POST request successfully"() { + given: + def executionId = "exec-123" + def taskName = "testTask" + def resource = Mock(Resource) + def input = [key: "value"] + def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null)) + def headers = new LinkedMultiValueMap() + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> executionId + getInput() >> input + getTaskInfo() >> taskInfo + getHeaders() >> headers + } + def requestParams = Mock(HttpParameter) { + getHeader() >> [contentType: MediaType.APPLICATION_JSON_VALUE] + } + def url = "http://test.com/api" + def expectedResponse = '{"status": "success"}' + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + 1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false + 1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams + 1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url + 1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> expectedResponse + 1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse) + result == expectedResponse + } + + def "should handle GET request successfully"() { given: - def httpParameter = HttpParameter.builder() - .header(inputHeader) - .body(inputBody) - .build() - MultiValueMap header = new LinkedMultiValueMap<>() - Optional.ofNullable(httpParameter.getHeader()) - .ifPresent { it -> it.forEach { key, value -> header.add(key, value) } } + def executionId = "exec-123" + def taskName = "testTask" + def resource = Mock(Resource) + def input = [key: "value"] + def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "GET", false, null, null, null, null, null, null)) + def headers = new LinkedMultiValueMap() + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> executionId + getInput() >> input + getTaskInfo() >> taskInfo + getHeaders() >> headers + } + def requestParams = Mock(HttpParameter) + def url = "http://test.com/api" + def expectedResponse = '{"status": "success"}' when: - def httpEntity = dispatcher.buildHttpEntity(method, header, httpParameter) + def result = dispatcher.handle(resource, dispatchInfo) then: - httpEntity.body == body - - where: - method | inputHeader | inputBody | body - null | [:] | [:] | null - HttpMethod.GET | [:] | [:] | null - HttpMethod.POST | [:] | [:] | [:] - HttpMethod.POST | [:] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]] - HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_JSON_VALUE] | [k: "v", user: [name: "Bob"]] | [k: "v", user: [name: "Bob"]] - HttpMethod.POST | ["Content-Type": MediaType.APPLICATION_FORM_URLENCODED_VALUE] | [k: "v", name: "Bob"] | [k: ["v"], name: ["Bob"]] + 1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false + 1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams + 1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url + 1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.GET, 1) >> expectedResponse + 1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse) + result == expectedResponse } + def "should handle error response correctly"() { + given: + def executionId = "exec-123" + def taskName = "testTask" + def resource = Mock(Resource) + def input = [key: "value"] + def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null)) + def headers = new LinkedMultiValueMap() + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> executionId + getInput() >> input + getTaskInfo() >> taskInfo + getHeaders() >> headers + } + def requestParams = Mock(HttpParameter) + def url = "http://test.com/api" + def errorResponse = "Error occurred" + def exception = Mock(RestClientResponseException) { + getResponseBodyAsString() >> errorResponse + getRawStatusCode() >> 500 + } + + when: + dispatcher.handle(resource, dispatchInfo) + + then: + 1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false + 1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams + 1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url + 1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> { throw exception } + 1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, errorResponse) + thrown(TaskException) + } + + def "should handle form-urlencoded POST request"() { + given: + def executionId = "exec-123" + def taskName = "testTask" + def resource = Mock(Resource) + def input = [key: "value"] + def taskInfo = new TaskInfo(name: taskName, task: new FunctionTask(taskName, null, null, "function", null, false, null, null, null, null, null, null, null, null, null, null, null, null, "POST", false, null, null, null, null, null, null)) + def headers = new LinkedMultiValueMap() + headers.add(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) + def dispatchInfo = Mock(DispatchInfo) { + getExecutionId() >> executionId + getInput() >> input + getTaskInfo() >> taskInfo + getHeaders() >> headers + } + def requestParams = Mock(HttpParameter) { + getBody() >> [stringParam: "test", mapParam: [key: "value"], listParam: ["item1"]] + } + def url = "http://test.com/api" + def expectedResponse = "success" + + when: + def result = dispatcher.handle(resource, dispatchInfo) + + then: + 1 * switcherManager.getSwitcherState("ENABLE_FUNCTION_DISPATCH_RET_CHECK") >> false + 1 * httpInvokeHelper.functionRequestParams(executionId, taskName, resource, input) >> requestParams + 1 * httpInvokeHelper.buildUrl(resource, requestParams.queryParams) >> url + 1 * httpInvokeHelper.invokeRequest(executionId, taskName, url, _ as HttpEntity, HttpMethod.POST, 1) >> expectedResponse + 1 * dagResourceStatistic.updateUrlTypeResourceStatus(executionId, taskName, _, expectedResponse) + result == expectedResponse + } }