diff --git a/android_world/env/android_world_controller.py b/android_world/env/android_world_controller.py index 2f313306..0cd11446 100644 --- a/android_world/env/android_world_controller.py +++ b/android_world/env/android_world_controller.py @@ -162,10 +162,12 @@ class AndroidWorldController(base_wrapper.BaseWrapper): def __init__( self, env: env_interface.AndroidEnvInterface, - a11y_method: A11yMethod = A11yMethod.A11Y_FORWARDER_APP, + a11y_method: A11yMethod | None = None, install_a11y_forwarding_app: bool = True, ): self._original_env = env + if a11y_method is None: + a11y_method = A11yMethod.A11Y_FORWARDER_APP if a11y_method == A11yMethod.A11Y_FORWARDER_APP: self._env = apply_a11y_forwarder_app_wrapper( env, install_a11y_forwarding_app @@ -201,6 +203,7 @@ def refresh_env(self): console_port=self.env._coordinator._simulator._config.emulator_launcher.emulator_console_port, adb_path=self.env._coordinator._simulator._config.adb_controller.adb_path, grpc_port=self.env._coordinator._simulator._config.emulator_launcher.grpc_port, + a11y_method=self._a11y_method, ).env # pylint: enable=protected-access # pytype: enable=attribute-error @@ -308,6 +311,7 @@ def get_controller( console_port: int = 5554, adb_path: str = DEFAULT_ADB_PATH, grpc_port: int = 8554, + a11y_method: A11yMethod | None = None, ) -> AndroidWorldController: """Creates a controller by connecting to an existing Android environment.""" @@ -326,4 +330,4 @@ def get_controller( ) android_env_instance = loader.load(config) logging.info('Setting up AndroidWorldController.') - return AndroidWorldController(android_env_instance) + return AndroidWorldController(android_env_instance, a11y_method=a11y_method) diff --git a/android_world/env/env_launcher_test.py b/android_world/env/env_launcher_test.py index 8bd5ee7f..00453684 100644 --- a/android_world/env/env_launcher_test.py +++ b/android_world/env/env_launcher_test.py @@ -56,7 +56,7 @@ def test_get_env( ), ) ) - mock_controller.assert_called_with(mock_android_env) + mock_controller.assert_called_with(mock_android_env, a11y_method=None) mock_async_android_env.assert_called_with(mock_controller.return_value)