diff --git a/openmetadata-integration-tests/src/test/java/org/openmetadata/it/tests/mcp/McpToolsValidationIT.java b/openmetadata-integration-tests/src/test/java/org/openmetadata/it/tests/mcp/McpToolsValidationIT.java index fae111ad3a19..ca871df1efcb 100644 --- a/openmetadata-integration-tests/src/test/java/org/openmetadata/it/tests/mcp/McpToolsValidationIT.java +++ b/openmetadata-integration-tests/src/test/java/org/openmetadata/it/tests/mcp/McpToolsValidationIT.java @@ -125,7 +125,34 @@ void testSearchMetadataTool() throws Exception { Map toolCall = McpTestUtils.createSearchMetadataToolCall("mcp_val_table", 5, Entity.TABLE); JsonNode result = executeToolCall(toolCall); - validateSearchMetadataResponse(result, "mcp_val_table"); + validateSearchMetadataResponse(result, "mcp_val_table", Entity.TABLE); + } + + @Test + @Order(1) + void testSearchMetadataEntityTypeFilterIsHonored() throws Exception { + // Regression test for https://github.com/open-metadata/OpenMetadata/issues/27796 + // Searching with a specific entityType must only return results of that type, not leak + // other types from the default dataAsset alias. + Map tableSearch = + McpTestUtils.createSearchMetadataToolCall("test", 10, Entity.TABLE); + JsonNode tableResult = executeToolCall(tableSearch); + validateSearchMetadataResponse(tableResult, "test", Entity.TABLE); + + Map dashboardSearch = + McpTestUtils.createSearchMetadataToolCall("test", 10, Entity.DASHBOARD); + JsonNode dashboardResult = executeToolCall(dashboardSearch); + JsonNode dashboardResponse = + OBJECT_MAPPER.readTree(dashboardResult.get("content").get(0).get("text").asText()); + dashboardResponse + .get("results") + .forEach( + r -> + assertThat(r.get("entityType").asText()) + .withFailMessage( + "Expected only %s results but got %s for entity %s", + Entity.DASHBOARD, r.get("entityType").asText(), r.get("name").asText()) + .isEqualTo(Entity.DASHBOARD)); } @Test @@ -444,6 +471,11 @@ private Map createSearchToolCallWithDeletedParam( private void validateSearchMetadataResponse(JsonNode result, String expectedQuery) throws Exception { + validateSearchMetadataResponse(result, expectedQuery, null); + } + + private void validateSearchMetadataResponse( + JsonNode result, String expectedQuery, String expectedEntityType) throws Exception { assertThat(result.has("content")).isTrue(); JsonNode content = result.get("content"); assertThat(content.isArray()).isTrue(); @@ -468,6 +500,13 @@ private void validateSearchMetadataResponse(JsonNode result, String expectedQuer .withFailMessage( "Missing 'deleted' field in search result for: " + r.get("name")) .isTrue(); + if (expectedEntityType != null) { + assertThat(r.get("entityType").asText()) + .withFailMessage( + "Expected entityType '%s' but got '%s' for result '%s'", + expectedEntityType, r.get("entityType").asText(), r.get("name").asText()) + .isEqualTo(expectedEntityType); + } matchingEntities.add(r.get("name").asText()); }); diff --git a/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SearchMetadataTool.java b/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SearchMetadataTool.java index 1f6b20d652ba..287c37991fdd 100644 --- a/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SearchMetadataTool.java +++ b/openmetadata-mcp/src/main/java/org/openmetadata/mcp/tools/SearchMetadataTool.java @@ -5,6 +5,7 @@ import static org.openmetadata.service.security.DefaultAuthorizer.getSubjectContext; import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.common.annotations.VisibleForTesting; import jakarta.ws.rs.core.Response; @@ -192,6 +193,8 @@ public Map execute( LOG.debug("Applied query filter to query: {}", queryFilter); } + queryFilter = addEntityTypeFilter(queryFilter, entityType); + LOG.info( "Search query: {}, index: {}, limit: {}, includeDeleted: {}", queryFilter, @@ -199,37 +202,26 @@ public Map execute( size, includeDeleted); - SearchRequest searchRequest; + boolean userProvidedQueryFilter = params.containsKey("queryFilter"); + SearchRequest searchRequest = + new SearchRequest() + .withIndex(Entity.getSearchRepository().getIndexOrAliasName(index)) + .withSize(size) + .withFrom(from) + .withFetchSource(true) + .withDeleted(includeDeleted); if (!nullOrEmpty(queryFilter)) { - // When queryFilter is provided, use it directly as it's already a transformed OpenSearch - // query - searchRequest = - new SearchRequest() - .withIndex(Entity.getSearchRepository().getIndexOrAliasName(index)) - .withQueryFilter(queryFilter) - .withSize(size) - .withFrom(from) - .withFetchSource(true) - .withDeleted(includeDeleted); - } else { - // Fallback to basic query when no queryFilter is provided - searchRequest = - new SearchRequest() - .withQuery(query) - .withIndex(Entity.getSearchRepository().getIndexOrAliasName(index)) - .withSize(size) - .withFrom(from) - .withFetchSource(true) - .withDeleted(includeDeleted); + searchRequest.withQueryFilter(queryFilter); + } + if (!userProvidedQueryFilter) { + searchRequest.withQuery(query); } SubjectContext subjectContext = getSubjectContext(securityContext); Response response; - if (!nullOrEmpty(queryFilter)) { - // Use direct query method when queryFilter is provided since it's already a transformed query + if (userProvidedQueryFilter) { response = Entity.getSearchRepository().searchWithDirectQuery(searchRequest, subjectContext); } else { - // Use regular search for basic queries response = Entity.getSearchRepository().search(searchRequest, subjectContext); } @@ -383,6 +375,66 @@ public static Map cleanSearchResponseObject(Map return object; } + /** + * Ensures results are constrained to the requested entityType by injecting an explicit + * {@code term} filter on the {@code entityType} field. Targeting an alias by itself is not + * always sufficient — for example, when the alias resolves to a multi-entity index or fans + * out to {@code dataAsset} — so the request can leak documents of other types. Adding the + * term filter guarantees correctness regardless of how the index alias resolves. + * + * @param existingFilter user-provided OpenSearch query JSON, already wrapped under "query", or + * {@code null} + * @param entityType requested entity type, or {@code null}/blank to leave the filter untouched + * @return a JSON string containing the merged query filter, or {@code existingFilter} if no + * entityType was provided + */ + @VisibleForTesting + static String addEntityTypeFilter(String existingFilter, String entityType) { + if (entityType == null || entityType.isBlank()) { + return existingFilter; + } + ObjectNode termFilter = JsonUtils.getObjectMapper().createObjectNode(); + termFilter.putObject("term").put("entityType", entityType); + if (nullOrEmpty(existingFilter)) { + ObjectNode bool = JsonUtils.getObjectMapper().createObjectNode(); + ArrayNode filterArray = bool.putObject("bool").putArray("filter"); + filterArray.add(termFilter); + ObjectNode wrapper = JsonUtils.getObjectMapper().createObjectNode(); + wrapper.set("query", bool); + return JsonUtils.pojoToJson(wrapper); + } + try { + JsonNode root = JsonUtils.getObjectMapper().readTree(existingFilter); + JsonNode queryNode = root.get("query"); + if (queryNode == null || !queryNode.isObject()) { + return existingFilter; + } + ObjectNode queryObject = (ObjectNode) queryNode; + ObjectNode boolNode; + if (queryObject.has("bool") && queryObject.get("bool").isObject()) { + boolNode = (ObjectNode) queryObject.get("bool"); + } else { + ObjectNode originalCopy = queryObject.deepCopy(); + queryObject.removeAll(); + boolNode = queryObject.putObject("bool"); + boolNode.putArray("must").add(originalCopy); + } + ArrayNode filterArray; + if (boolNode.has("filter") && boolNode.get("filter").isArray()) { + filterArray = (ArrayNode) boolNode.get("filter"); + } else { + filterArray = boolNode.putArray("filter"); + } + filterArray.add(termFilter); + return JsonUtils.pojoToJson(root); + } catch (IOException e) { + LOG.warn( + "Unable to merge entityType filter into provided queryFilter, leaving it unchanged: {}", + e.getMessage()); + return existingFilter; + } + } + /** * Truncates aggregation buckets to prevent excessive response size that could overwhelm LLM * context windows. Based on industry best practices, LLM performance degrades when context diff --git a/openmetadata-mcp/src/test/java/org/openmetadata/mcp/tools/SearchMetadataToolTest.java b/openmetadata-mcp/src/test/java/org/openmetadata/mcp/tools/SearchMetadataToolTest.java index 6ebd351c23fa..48328fa34264 100644 --- a/openmetadata-mcp/src/test/java/org/openmetadata/mcp/tools/SearchMetadataToolTest.java +++ b/openmetadata-mcp/src/test/java/org/openmetadata/mcp/tools/SearchMetadataToolTest.java @@ -2,11 +2,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.fasterxml.jackson.databind.JsonNode; import jakarta.ws.rs.core.Response; import java.security.Principal; import java.util.HashMap; @@ -15,9 +18,14 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; +import org.mockito.junit.jupiter.MockitoSettings; +import org.mockito.quality.Strictness; import org.openmetadata.schema.entity.teams.User; +import org.openmetadata.schema.search.SearchRequest; +import org.openmetadata.schema.utils.JsonUtils; import org.openmetadata.service.Entity; import org.openmetadata.service.search.SearchRepository; import org.openmetadata.service.security.Authorizer; @@ -35,6 +43,7 @@ * - Response formatting and structure */ @ExtendWith(MockitoExtension.class) +@MockitoSettings(strictness = Strictness.LENIENT) class SearchMetadataToolTest { private SearchMetadataTool searchMetadataTool; @@ -135,4 +144,122 @@ void testSearchWithSpecificEntityType() throws Exception { assertNotNull(result.get("results")); } } + + @Test + void testEntityTypeIsAlwaysAppliedAsExplicitFilter() throws Exception { + try (MockedStatic subjectCacheMock = mockStatic(SubjectCache.class)) { + subjectCacheMock.when(() -> SubjectCache.getUserContext("test-user")).thenReturn(mockUser); + + Map params = new HashMap<>(); + params.put("query", "test"); + params.put("entityType", "metric"); + + when(searchRepository.getIndexOrAliasName("metric")).thenReturn("metric"); + + Response mockResponse = mock(Response.class); + when(mockResponse.getEntity()).thenReturn("{\"hits\":{\"hits\":[],\"total\":{\"value\":0}}}"); + when(searchRepository.search(any(), any(SubjectContext.class))).thenReturn(mockResponse); + + searchMetadataTool.execute(authorizer, securityContext, params); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(searchRepository).search(captor.capture(), any(SubjectContext.class)); + + SearchRequest sent = captor.getValue(); + assertEquals("test", sent.getQuery()); + assertNotNull(sent.getQueryFilter()); + JsonNode filter = JsonUtils.readTree(sent.getQueryFilter()); + assertEquals("metric", filter.at("/query/bool/filter/0/term/entityType").asText()); + } + } + + @Test + void testNoEntityTypeLeavesQueryFilterEmpty() throws Exception { + try (MockedStatic subjectCacheMock = mockStatic(SubjectCache.class)) { + subjectCacheMock.when(() -> SubjectCache.getUserContext("test-user")).thenReturn(mockUser); + + Map params = new HashMap<>(); + params.put("query", "test"); + + when(searchRepository.getIndexOrAliasName("dataAsset")).thenReturn("dataAsset"); + + Response mockResponse = mock(Response.class); + when(mockResponse.getEntity()).thenReturn("{\"hits\":{\"hits\":[],\"total\":{\"value\":0}}}"); + when(searchRepository.search(any(), any(SubjectContext.class))).thenReturn(mockResponse); + + searchMetadataTool.execute(authorizer, securityContext, params); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(searchRepository).search(captor.capture(), any(SubjectContext.class)); + + SearchRequest sent = captor.getValue(); + assertEquals("test", sent.getQuery()); + assertNull(sent.getQueryFilter()); + } + } + + @Test + void testAddEntityTypeFilterWithoutExistingFilter() throws Exception { + String result = SearchMetadataTool.addEntityTypeFilter(null, "metric"); + assertNotNull(result); + JsonNode root = JsonUtils.readTree(result); + assertEquals("metric", root.at("/query/bool/filter/0/term/entityType").asText()); + } + + @Test + void testAddEntityTypeFilterWrapsExistingBoolQuery() throws Exception { + String existing = "{\"query\":{\"bool\":{\"must\":[{\"term\":{\"owners.name\":\"team\"}}]}}}"; + String result = SearchMetadataTool.addEntityTypeFilter(existing, "topic"); + JsonNode root = JsonUtils.readTree(result); + assertEquals("topic", root.at("/query/bool/filter/0/term/entityType").asText()); + assertEquals("team", root.at("/query/bool/must/0/term/owners.name").asText()); + } + + @Test + void testAddEntityTypeFilterWrapsNonBoolQuery() throws Exception { + String existing = "{\"query\":{\"term\":{\"owners.name\":\"team\"}}}"; + String result = SearchMetadataTool.addEntityTypeFilter(existing, "pipeline"); + JsonNode root = JsonUtils.readTree(result); + assertEquals("pipeline", root.at("/query/bool/filter/0/term/entityType").asText()); + assertEquals("team", root.at("/query/bool/must/0/term/owners.name").asText()); + } + + @Test + void testAddEntityTypeFilterIsNoopWhenEntityTypeMissing() { + String existing = "{\"query\":{\"term\":{\"owners.name\":\"team\"}}}"; + assertEquals(existing, SearchMetadataTool.addEntityTypeFilter(existing, null)); + assertEquals(existing, SearchMetadataTool.addEntityTypeFilter(existing, "")); + assertEquals(existing, SearchMetadataTool.addEntityTypeFilter(existing, " ")); + assertNull(SearchMetadataTool.addEntityTypeFilter(null, null)); + } + + @Test + void testEntityTypeFilterMergesWithUserQueryFilter() throws Exception { + try (MockedStatic subjectCacheMock = mockStatic(SubjectCache.class)) { + subjectCacheMock.when(() -> SubjectCache.getUserContext("test-user")).thenReturn(mockUser); + + Map params = new HashMap<>(); + params.put("entityType", "metric"); + params.put( + "queryFilter", + "{\"query\":{\"bool\":{\"must\":[{\"term\":{\"owners.name\":\"finance\"}}]}}}"); + + when(searchRepository.getIndexOrAliasName("metric")).thenReturn("metric"); + + Response mockResponse = mock(Response.class); + when(mockResponse.getEntity()).thenReturn("{\"hits\":{\"hits\":[],\"total\":{\"value\":0}}}"); + when(searchRepository.searchWithDirectQuery(any(), any(SubjectContext.class))) + .thenReturn(mockResponse); + + searchMetadataTool.execute(authorizer, securityContext, params); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SearchRequest.class); + verify(searchRepository).searchWithDirectQuery(captor.capture(), any(SubjectContext.class)); + + SearchRequest sent = captor.getValue(); + JsonNode filter = JsonUtils.readTree(sent.getQueryFilter()); + assertEquals("metric", filter.at("/query/bool/filter/0/term/entityType").asText()); + assertEquals("finance", filter.at("/query/bool/must/0/term/owners.name").asText()); + } + } }