diff --git a/py/src/braintrust/framework2.py b/py/src/braintrust/framework2.py index 781dc2ca..c5669e8f 100644 --- a/py/src/braintrust/framework2.py +++ b/py/src/braintrust/framework2.py @@ -544,6 +544,65 @@ def create( return p +class ClassifierBuilder: + """Builder to create a classifier in Braintrust.""" + + def __init__(self, project: "Project"): + self.project = project + self._task_counter = 0 + + def create( + self, + *, + handler: Callable[..., Any], + name: str | None = None, + slug: str | None = None, + description: str | None = None, + parameters: Any = None, + returns: Any = None, + if_exists: IfExists | None = None, + metadata: Metadata | None = None, + tags: Sequence[str] | None = None, + ) -> CodeFunction: + """Creates a classifier. + + Args: + handler: The function that is called when the classifier is used. + name: The name of the classifier. + slug: A unique identifier for the classifier. + description: The description of the classifier. + parameters: The classifier's input schema, as a Pydantic model. + returns: The classifier's output schema, as a Pydantic model. + if_exists: What to do if the classifier already exists. + metadata: Custom metadata to attach to the classifier. + tags: A list of tags for the classifier. + """ + self._task_counter += 1 + if name is None or len(name) == 0: + if handler.__name__ and handler.__name__ != "": + name = handler.__name__ + else: + name = f"Classifier {self._task_counter}" + if slug is None or len(slug) == 0: + slug = slugify.slugify(name) + + f = CodeFunction( + project=self.project, + handler=handler, + name=name, + slug=slug, + type_="classifier", + description=description, + parameters=parameters, + returns=returns, + if_exists=if_exists, + metadata=metadata, + tags=tags, + ) + self.project.add_code_function(f) + return f + + class Project: """A handle to a Braintrust project.""" @@ -553,6 +612,7 @@ def __init__(self, name: str): self.prompts = PromptBuilder(self) self.parameters = ParametersBuilder(self) self.scorers = ScorerBuilder(self) + self.classifiers = ClassifierBuilder(self) self._publishable_code_functions: list[CodeFunction] = [] self._publishable_prompts: list[CodePrompt] = [] diff --git a/py/src/braintrust/test_framework2.py b/py/src/braintrust/test_framework2.py index d8745a4a..6ae930ba 100644 --- a/py/src/braintrust/test_framework2.py +++ b/py/src/braintrust/test_framework2.py @@ -5,7 +5,8 @@ import pytest -from .framework2 import ProjectIdCache, projects +from .framework import _set_lazy_load +from .framework2 import ProjectIdCache, global_, projects # Check if pydantic is available @@ -145,6 +146,50 @@ def test_llm_scorer_with_metadata(self): assert scorer.name == "llm-scorer" +class TestClassifierFunctionPush: + """Tests for classifier function push registration.""" + + def test_code_classifier_registers_as_code_function(self): + project = projects.create("test-project") + + def classify(output: str): + return {"name": "category", "id": output} + + classifier = project.classifiers.create( + handler=classify, + name="test-classifier", + ) + + assert classifier.type_ == "classifier" + assert classifier.name == "test-classifier" + assert classifier.slug == "test-classifier" + assert project._publishable_code_functions == [classifier] + + def test_lazy_code_classifier_registration_uses_functions_registry(self): + global_.functions.clear() + global_.prompts.clear() + global_.parameters.clear() + + try: + with _set_lazy_load(True): + project = projects.create("test-project") + + def classify(output: str): + return {"name": "category", "id": output} + + classifier = project.classifiers.create( + handler=classify, + name="test-classifier", + ) + + assert global_.functions == [classifier] + assert global_.functions[0].type_ == "classifier" + finally: + global_.functions.clear() + global_.prompts.clear() + global_.parameters.clear() + + class TestCodeFunctionTags: """Tests for CodeFunction tags support."""