Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions py/src/braintrust/framework2.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,65 @@ def create(
return p


class ClassifierBuilder:
"""Builder to create a classifier in Braintrust."""

def __init__(self, project: "Project"):
self.project = project
self._task_counter = 0

def create(
self,
*,
handler: Callable[..., Any],
name: str | None = None,
slug: str | None = None,
description: str | None = None,
parameters: Any = None,
returns: Any = None,
if_exists: IfExists | None = None,
metadata: Metadata | None = None,
tags: Sequence[str] | None = None,
) -> CodeFunction:
"""Creates a classifier.

Args:
handler: The function that is called when the classifier is used.
name: The name of the classifier.
slug: A unique identifier for the classifier.
description: The description of the classifier.
parameters: The classifier's input schema, as a Pydantic model.
returns: The classifier's output schema, as a Pydantic model.
if_exists: What to do if the classifier already exists.
metadata: Custom metadata to attach to the classifier.
tags: A list of tags for the classifier.
"""
self._task_counter += 1
if name is None or len(name) == 0:
if handler.__name__ and handler.__name__ != "<lambda>":
name = handler.__name__
else:
name = f"Classifier {self._task_counter}"
if slug is None or len(slug) == 0:
slug = slugify.slugify(name)

f = CodeFunction(
project=self.project,
handler=handler,
name=name,
slug=slug,
type_="classifier",
description=description,
parameters=parameters,
returns=returns,
if_exists=if_exists,
metadata=metadata,
tags=tags,
)
self.project.add_code_function(f)
return f


class Project:
"""A handle to a Braintrust project."""

Expand All @@ -553,6 +612,7 @@ def __init__(self, name: str):
self.prompts = PromptBuilder(self)
self.parameters = ParametersBuilder(self)
self.scorers = ScorerBuilder(self)
self.classifiers = ClassifierBuilder(self)

self._publishable_code_functions: list[CodeFunction] = []
self._publishable_prompts: list[CodePrompt] = []
Expand Down
47 changes: 46 additions & 1 deletion py/src/braintrust/test_framework2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import pytest

from .framework2 import ProjectIdCache, projects
from .framework import _set_lazy_load
from .framework2 import ProjectIdCache, global_, projects


# Check if pydantic is available
Expand Down Expand Up @@ -145,6 +146,50 @@ def test_llm_scorer_with_metadata(self):
assert scorer.name == "llm-scorer"


class TestClassifierFunctionPush:
"""Tests for classifier function push registration."""

def test_code_classifier_registers_as_code_function(self):
project = projects.create("test-project")

def classify(output: str):
return {"name": "category", "id": output}

classifier = project.classifiers.create(
handler=classify,
name="test-classifier",
)

assert classifier.type_ == "classifier"
assert classifier.name == "test-classifier"
assert classifier.slug == "test-classifier"
assert project._publishable_code_functions == [classifier]

def test_lazy_code_classifier_registration_uses_functions_registry(self):
global_.functions.clear()
global_.prompts.clear()
global_.parameters.clear()

try:
with _set_lazy_load(True):
project = projects.create("test-project")

def classify(output: str):
return {"name": "category", "id": output}

classifier = project.classifiers.create(
handler=classify,
name="test-classifier",
)

assert global_.functions == [classifier]
assert global_.functions[0].type_ == "classifier"
finally:
global_.functions.clear()
global_.prompts.clear()
global_.parameters.clear()


class TestCodeFunctionTags:
"""Tests for CodeFunction tags support."""

Expand Down
Loading