diff --git a/noxfile.py b/noxfile.py index da8a2863d7..0e0e486515 100644 --- a/noxfile.py +++ b/noxfile.py @@ -223,6 +223,7 @@ def default(session): "--ignore=tests/unit/architecture", "--ignore=tests/unit/vertexai/genai/replays", "--ignore=tests/unit/agentplatform/genai/replays", + "--ignore=tests/unit/agentplatform/frameworks", os.path.join("tests", "unit"), *session.posargs, ) @@ -301,12 +302,12 @@ def unit_ray(session, ray): def unit_langchain(session): # Install all test dependencies, then install this package in-place. - constraints_path = str(CURRENT_DIRECTORY / "testing" / "constraints-langchain.txt") + constraints_path = str(CURRENT_DIRECTORY / "testing" / "constraints-ag2.txt") standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES session.install(*standard_deps, "-c", constraints_path) # Install langchain extras - session.install("-e", ".[langchain_testing]", "-c", constraints_path) + session.install("-e", ".[ag2_testing]", "-c", constraints_path) # Run py.test against the unit tests. session.run( @@ -318,7 +319,7 @@ def unit_langchain(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit", "vertex_langchain"), + os.path.join("tests", "unit", "agentplatform", "frameworks", "test_frameworks_ag2.py"), *session.posargs, ) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py b/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py new file mode 100644 index 0000000000..2502e3d8db --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py @@ -0,0 +1,405 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import importlib +import json +from typing import Optional, List, Dict, Any +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines +from agentplatform.agent_engines import _utils +import pytest + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_RUNNABLE_NAME = "test-runnable" +_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot." + +# Assume these are defined in _utils or imported there +JsonDict = Dict[str, Any] + +class AutogenRunResponseMock: + """Explicit mock for AutogenRunResponse to avoid mock serialization issues.""" + + def __init__( + self, + summary: Optional[str] = None, + messages: Optional[List[Dict[str, Any]]] = None, + context_variables: Optional[Dict[str, Any]] = None, + last_speaker: Optional[Any] = None, + cost: Optional[Any] = None, + process: Optional[Any] = None, + ): + self.summary = summary + self.messages = messages or [] + self.context_variables = context_variables + self.last_speaker = last_speaker + self.cost = cost + if process is not None: + self.process = process + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def dataclasses_asdict_mock(): + with mock.patch.object(dataclasses, "asdict") as dataclasses_asdict_mock: + dataclasses_asdict_mock.return_value = {} + yield dataclasses_asdict_mock + + +@pytest.fixture +def dataclasses_is_dataclass_mock(): + with mock.patch.object( + dataclasses, "is_dataclass" + ) as dataclasses_is_dataclass_mock: + dataclasses_is_dataclass_mock.return_value = True + yield dataclasses_is_dataclass_mock + + +@pytest.fixture +def to_json_serializable_autogen_object_mock(): + with mock.patch.object( + _utils, + "to_json_serializable_autogen_object", + ) as to_json_serializable_autogen_object_mock: + to_json_serializable_autogen_object_mock.return_value = {} + yield to_json_serializable_autogen_object_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def autogen_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_autogen_or_warn", + ) as autogen_instrumentor_mock: + yield autogen_instrumentor_mock + + +@pytest.fixture +def autogen_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_autogen_or_warn", + ) as autogen_instrumentor_mock: + autogen_instrumentor_mock.return_value = None + yield autogen_instrumentor_mock + + +@pytest.fixture +def autogen_tools_mock(): + with mock.patch.object( + _utils, + "_import_autogen_tools_or_warn", + ) as autogen_tools_mock: + autogen_tools_mock.return_value = mock.MagicMock() + yield autogen_tools_mock + + +class MockAgent: + def __init__(self, name=None, description=None): + self.name = name + self.description = description + + +class MockCost: + def __init__(self, total_cost=0.0): + self.total_cost = total_cost + + def model_dump_json(self): + return json.dumps({"total_cost": self.total_cost}) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAG2Agent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME + ) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("runnable_name") == _TEST_RUNNABLE_NAME + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self, autogen_tools_mock): + tools = [ + place_tool_query, + place_photo_query, + ] + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + tools=tools, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + assert agent._tmpl_attrs.get("tools") + assert not agent._tmpl_attrs.get("ag2_tool_objects") + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent._tmpl_attrs.get("ag2_tool_objects") + + def test_set_up(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, to_json_serializable_autogen_object_mock): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + ) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="run") + agent.query(input="test query") + + mocks.assert_has_calls( + [ + mock.call.run.run( + message={"content": "test query"}, + user_input=False, + tools=[], + max_turns=None, + ) + ] + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + autogen_instrumentor_mock, + ): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._tmpl_attrs.get("instrumentor") is not None + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, autogen_instrumentor_none_mock): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + tools=[_return_input_no_typing], + ) + + +class TestToJsonSerializableAutoGenObject: + """Tests for `_utils.to_json_serializable_autogen_object`.""" + + def test_autogen_chat_result( + self, + dataclasses_asdict_mock, + dataclasses_is_dataclass_mock, + ): + mock_chat_result: _utils.AutogenChatResult = mock.Mock( + spec=_utils.AutogenChatResult + ) + _utils.to_json_serializable_autogen_object(mock_chat_result) + dataclasses_is_dataclass_mock.assert_called_once_with(mock_chat_result) + dataclasses_asdict_mock.assert_called_once_with(mock_chat_result) + + def test_autogen_run_response(self): + mock_process = mock.MagicMock() + mock_agent = MockAgent( + name="TestAgent", + description="Agent Description", + ) + mock_cost = MockCost(total_cost=5.5) + mock_response = AutogenRunResponseMock( + summary="summary", + messages=[{"role": "user", "content": "Hello"}], + context_variables={"var1": "value1"}, + last_speaker=mock_agent, + cost=mock_cost, + process=mock_process, + ) + + want = { + "summary": "summary", + "messages": [{"role": "user", "content": "Hello"}], + "context_variables": {"var1": "value1"}, + "last_speaker": { + "name": "TestAgent", + "description": "Agent Description", + }, + "cost": {"total_cost": 5.5}, + } + got = _utils.to_json_serializable_autogen_object(mock_response) + mock_response.process.assert_called_once() + assert got == want + + def test_autogen_empty_run_response(self): + mock_response = AutogenRunResponseMock() + want = { + "summary": None, + "messages": [], + "context_variables": None, + "last_speaker": None, + "cost": None, + } + got = _utils.to_json_serializable_autogen_object(mock_response) + assert got == want + + +class TestDataClassToJsonSerializable: + """Tests for `_utils._dataclass_to_dict_or_raise`.""" + + def test_valid_dataclass(self): + @dataclasses.dataclass + class SimpleDataClass: + field1: str + field2: int + + instance = SimpleDataClass(field1="value1", field2=123) + want = {"field1": "value1", "field2": 123} + got = _utils._dataclass_to_dict_or_raise(instance) + assert got == want + + def test_not_a_dataclass_raises_type_error(self): + class NotADataclass: + pass + + instance = NotADataclass() + with pytest.raises(TypeError, match="Object is not a dataclass"): + _utils._dataclass_to_dict_or_raise(instance) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py b/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py new file mode 100644 index 0000000000..54c7f5834d --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py @@ -0,0 +1,300 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +from typing import Optional +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines + +from agentplatform.agent_engines import _utils +import pytest + + +from langchain_core import prompts +from langchain_core.load import dump as langchain_load_dump +from langchain_classic.agents.format_scratchpad import ( + format_to_openai_function_messages, +) +from langchain_core.tools import StructuredTool + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot." + + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def langchain_dump_mock(): + with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock: + yield langchain_dump_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def langchain_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + yield langchain_instrumentor_mock + + +@pytest.fixture +def langchain_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + langchain_instrumentor_mock.return_value = None + yield langchain_instrumentor_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLangchainAgent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + self.prompt = { + "input": lambda x: x["input"], + "agent_scratchpad": ( + lambda x: format_to_openai_function_messages(x["intermediate_steps"]) + ), + } | prompts.ChatPromptTemplate.from_messages( + [ + ("user", "{input}"), + prompts.MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + self.output_parser = mock.Mock() + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.LangchainAgent(model=_TEST_MODEL) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self): + tools = [ + place_tool_query, + StructuredTool.from_function(place_photo_query), + ] + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + tools=tools, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + for tool, agent_tool in zip(tools, agent._tmpl_attrs.get("tools")): + assert isinstance(agent_tool, type(tool)) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_set_up(self): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, langchain_dump_mock): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + ) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="invoke") + agent.query(input="test query") + mocks.assert_has_calls( + [mock.call.invoke.invoke(input={"input": "test query"}, config=None)] + ) + + def test_stream_query(self, langchain_dump_mock): + agent = agent_engines.LangchainAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].stream.return_value = [] + list(agent.stream_query(input="test stream query")) + agent._tmpl_attrs["runnable"].stream.assert_called_once_with( + input={"input": "test stream query"}, + config=None, + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + langchain_instrumentor_mock, + ): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._tmpl_attrs.get("instrumentor") is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.LangchainAgent( + model=_TEST_MODEL, + tools=[_return_input_no_typing], + ) + + +class TestSystemInstructionAndPromptRaisesErrors: + def test_raise_both_system_instruction_and_prompt_error(self, agentplatform_init_mock): + with pytest.raises( + ValueError, + match=r"Only one of `prompt` or `system_instruction` should be specified.", + ): + agent_engines.LangchainAgent( + model=_TEST_MODEL, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + prompt=prompts.ChatPromptTemplate.from_messages( + [ + ("user", "{input}"), + ] + ), + ) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py b/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py new file mode 100644 index 0000000000..09318ae514 --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py @@ -0,0 +1,365 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +from typing import Any, Dict, List, Optional +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines +from agentplatform.agent_engines import _utils +import pytest + +from langchain_core import runnables +from langchain_core.load import dump as langchain_load_dump +from langchain_core.tools import StructuredTool + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_CONFIG = runnables.RunnableConfig(configurable={"thread_id": "thread-values"}) + + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +def _checkpointer_builder(**unused_kwargs): + try: + from langgraph.checkpoint import memory + except ImportError: + from langgraph_checkpoint.checkpoint import memory + + return memory.MemorySaver() + + +def _get_state_messages(state: Dict[str, Any]) -> List[str]: + messages = [] + for message in state.get("values").get("messages"): + messages.append(message.content) + return messages + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def langchain_dump_mock(): + with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock: + yield langchain_dump_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def langchain_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + yield langchain_instrumentor_mock + + +@pytest.fixture +def langchain_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + langchain_instrumentor_mock.return_value = None + yield langchain_instrumentor_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLanggraphAgent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self): + tools = [ + place_tool_query, + StructuredTool.from_function(place_photo_query), + ] + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + tools=tools, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + for tool, agent_tool in zip(tools, agent._tmpl_attrs.get("tools")): + assert isinstance(agent_tool, type(tool)) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_set_up(self): + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, langchain_dump_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke") + agent.query(input="test query") + mocks.assert_has_calls( + [ + mock.call.invoke.invoke( + input={"input": "test query", "messages": [("user", "test query")]}, + config=None, + ) + ] + ) + + def test_stream_query(self, langchain_dump_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].stream.return_value = [] + list(agent.stream_query(input="test stream query")) + agent._tmpl_attrs["runnable"].stream.assert_called_once_with( + input={ + "input": "test stream query", + "messages": [("user", "test stream query")], + }, + config=None, + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + langchain_instrumentor_mock, + ): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._instrumentor is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/383923584): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + def test_get_state_history_empty(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [] + history = list(agent.get_state_history()) + assert history == [] + + def test_get_state_history(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [ + mock.Mock(), + mock.Mock(), + ] + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 0 + ]._asdict.return_value = {"test_key_1": "test_value_1"} + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 1 + ]._asdict.return_value = {"test_key_2": "test_value_2"} + history = list(agent.get_state_history()) + assert history == [ + {"test_key_1": "test_value_1"}, + {"test_key_2": "test_value_2"}, + ] + + def test_get_state_history_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [ + mock.Mock(), + mock.Mock(), + ] + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 0 + ]._asdict.return_value = {"test_key_1": "test_value_1"} + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 1 + ]._asdict.return_value = {"test_key_2": "test_value_2"} + history = list(agent.get_state_history(config=_TEST_CONFIG)) + assert history == [ + {"test_key_1": "test_value_1"}, + {"test_key_2": "test_value_2"}, + ] + + def test_get_state(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value._asdict.return_value = { + "test_key": "test_value" + } + state = agent.get_state() + assert state == {"test_key": "test_value"} + + def test_get_state_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value._asdict.return_value = { + "test_key": "test_value" + } + state = agent.get_state(config=_TEST_CONFIG) + assert state == {"test_key": "test_value"} + + def test_update_state(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state() + agent._tmpl_attrs["runnable"].update_state.assert_called_once() + + def test_update_state_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state(config=_TEST_CONFIG) + agent._tmpl_attrs["runnable"].update_state.assert_called_once_with( + config=_TEST_CONFIG + ) + + def test_update_state_with_config_and_kwargs(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state(config=_TEST_CONFIG, test_key="test_value") + agent._tmpl_attrs["runnable"].update_state.assert_called_once_with( + config=_TEST_CONFIG, test_key="test_value" + ) + + def test_register_operations(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + expected_operations = { + "": ["query", "get_state", "update_state"], + "stream": ["stream_query", "get_state_history"], + } + assert agent.register_operations() == expected_operations + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.LanggraphAgent( + model=_TEST_MODEL, tools=[_return_input_no_typing] + )