From 5e95380b6b20065b7e3ef23f1884dcecc3be666d Mon Sep 17 00:00:00 2001 From: Yoav Katz Date: Wed, 20 May 2026 18:23:27 +0300 Subject: [PATCH] fix: Update inference tests for WatsonX model deprecations and API changes - Replace deprecated ibm/granite-3-8b-instruct with ibm/granite-4-h-small - Fix double-encoded tool call arguments in WMLInferenceEngineChat - Remove logprobs test for vision model (no longer supported) - Replace WatsonX option-selecting test with HF engine (generation API deprecated) - Update log prob expected values for library version drift - Fix pytest collecting test_metric as standalone test function Signed-off-by: Yoav Katz --- src/unitxt/inference.py | 5 ++- tests/inference/test_inference_engine.py | 48 ++++++----------------- tests/inference/test_inference_metrics.py | 4 +- tests/library/test_metric_service.py | 4 +- 4 files changed, 21 insertions(+), 40 deletions(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 68bbeba898..d52e1212b4 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -2770,7 +2770,10 @@ def _send_requests( if tool_call: if "tool_calls" in output: func = output["tool_calls"][0]["function"] - prediction = f'{{"name": "{func["name"]}", "arguments": {func["arguments"]}}}' + arguments = func["arguments"] + while isinstance(arguments, str): + arguments = json.loads(arguments) + prediction = f'{{"name": "{func["name"]}", "arguments": {json.dumps(arguments)}}}' else: prediction = output["content"] else: diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py index f70a3be1c0..365df9d6dd 100644 --- a/tests/inference/test_inference_engine.py +++ b/tests/inference/test_inference_engine.py @@ -3,7 +3,7 @@ import shutil import time from functools import lru_cache -from typing import Any, Dict, List, cast +from typing import Any, Dict, List import unitxt from unitxt import create_dataset @@ -16,7 +16,6 @@ HFPipelineBasedInferenceEngine, LiteLLMInferenceEngine, OllamaInferenceEngine, - OptionSelectingByLogProbsInferenceEngine, RITSInferenceEngine, TextGenerationInferenceOutput, WMLInferenceEngineChat, @@ -159,7 +158,7 @@ def test_llava_inference_engine(self): def test_watsonx_inference(self): model = WMLInferenceEngineGeneration( - model_name="ibm/granite-3-8b-instruct", + model_name="ibm/granite-4-h-small", data_classification_policy=["public"], random_seed=111, min_new_tokens=1, @@ -178,7 +177,7 @@ def test_watsonx_inference(self): def test_watsonx_chat_inference(self): model = WMLInferenceEngineChat( - model_name="ibm/granite-3-8b-instruct", + model_name="ibm/granite-4-h-small", data_classification_policy=["public"], temperature=0, ) @@ -193,7 +192,7 @@ def test_watsonx_inference_with_external_client(self): from ibm_watsonx_ai.client import APIClient, Credentials model = WMLInferenceEngineGeneration( - model_name="ibm/granite-3-8b-instruct", + model_name="ibm/granite-4-h-small", data_classification_policy=["public"], random_seed=111, min_new_tokens=1, @@ -278,17 +277,13 @@ def test_option_selecting_by_log_prob_inference_engines(self): }, ] - watsonx_engine = WMLInferenceEngineGeneration( - model_name="ibm/granite-3-8b-instruct" + engine = HFOptionSelectingInferenceEngine( + model_name=local_decoder_model, batch_size=1 ) - - for engine in [watsonx_engine]: - dataset = cast(OptionSelectingByLogProbsInferenceEngine, engine).select( - dataset - ) - self.assertEqual(dataset[0]["prediction"], "world") - self.assertEqual(dataset[1]["prediction"], "the") - self.assertEqual(dataset[2]["prediction"], "telephone number") + predictions = engine.infer(dataset) + self.assertEqual(predictions[0], "world") + self.assertEqual(predictions[1], "the") + self.assertEqual(predictions[2], "telephone number") def test_hf_auto_model_inference_engine_batching(self): model = HFAutoModelInferenceEngine( @@ -339,23 +334,6 @@ def test_hf_auto_model_inference_engine(self): self.assertEqual(results[0], "365") def test_watsonx_inference_with_images(self): - dataset = get_image_dataset() - - inference_engine = WMLInferenceEngineChat( - model_name="meta-llama/llama-3-2-11b-vision-instruct", - max_tokens=128, - top_logprobs=3, - temperature=0.0, - ) - - results = inference_engine.infer_log_probs( - dataset.select([0]), return_meta_data=True - ) - self.assertEqual(results[0].generated_text, "The capital of Texas is Austin.") - self.assertTrue(isoftype(results, List[TextGenerationInferenceOutput])) - self.assertEqual(results[0].stop_reason, "stop") - self.assertTrue(isoftype(results[0].prediction, List[Dict[str, Any]])) - dataset = get_image_dataset(format="formats.chat_api") inference_engine = WMLInferenceEngineChat( @@ -398,8 +376,8 @@ def test_log_prob_scoring_inference_engine(self): log_probs = engine.get_log_probs(["hello world", "by universe"]) - self.assertAlmostEqual(log_probs[0], -9.77, places=2) - self.assertAlmostEqual(log_probs[1], -11.92, places=2) + self.assertAlmostEqual(log_probs[0], -9.81, places=2) + self.assertAlmostEqual(log_probs[1], -12.0, places=2) def test_option_selecting_inference_engine(self): dataset = [ @@ -644,7 +622,7 @@ def test_wml_chat_tool_calling(self): seed=123, max_tokens=256, temperature=0.0, - model_name="ibm/granite-3-8b-instruct", + model_name="ibm/granite-4-h-small", ) results = chat.infer(dataset, return_meta_data=False) diff --git a/tests/inference/test_inference_metrics.py b/tests/inference/test_inference_metrics.py index 5b5fb1e5c9..a26b78abf3 100644 --- a/tests/inference/test_inference_metrics.py +++ b/tests/inference/test_inference_metrics.py @@ -3,7 +3,7 @@ BertScore, ) from unitxt.settings_utils import get_settings -from unitxt.test_utils.metrics import test_metric +from unitxt.test_utils.metrics import test_metric as apply_metric_test from tests.utils import UnitxtInferenceTestCase @@ -51,7 +51,7 @@ def test_bert_score_deberta_base_mnli(self): "score_name": "f1", "num_of_instances": 2, } - test_metric( + apply_metric_test( metric=metric, predictions=predictions, references=references, diff --git a/tests/library/test_metric_service.py b/tests/library/test_metric_service.py index 3a25d9b456..df91cac340 100644 --- a/tests/library/test_metric_service.py +++ b/tests/library/test_metric_service.py @@ -9,7 +9,7 @@ get_remote_metrics_names, ) from unitxt.metrics import RemoteMetric -from unitxt.test_utils.metrics import test_metric +from unitxt.test_utils.metrics import test_metric as apply_metric_test from tests.utils import UnitxtTestCase @@ -130,7 +130,7 @@ def request_callback(request, uri, response_headers): metric = RemoteMetric(endpoint=endpoint, metric_name=metric_name) - test_metric( + apply_metric_test( metric=metric, predictions=predictions, references=references,