Skip to content

Commit b250a57

Browse files
rgambeegithub-actions[bot]
authored andcommitted
Integrate Tom's SOTA forecaster (#5477)
This integrates Tom's SOTA forecaster into the app. Users can toggle between it and the existing forecaster by setting the effort level parameter when using the SDK or MCP server. I tried to keep the implementation faithful to Tom's prototypes. We run three research agents on each question. Then we refine those three estimates with an additional LLM call to produce a final forecast (no further research). Lastly, we produce a user-facing summary. A potential future optimization would be to combine the refinement and summary steps into one, which would save time. The model selections are the same as Tom settled on. The prompts are also the same except for some slight wording changes. Most notable is adding units to the prompts for numeric forecasts. Based on my limited testing, the cost is about $0.75 per row, but it would be good to average across a wider range of tasks. Sourced from commit e26e249e30afd2673271297dcc11e4e065f634b7
1 parent ef3eddf commit b250a57

7 files changed

Lines changed: 52 additions & 13 deletions

File tree

futuresearch-mcp/src/futuresearch_mcp/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from futuresearch.generated.models.dedupe_operation_strategy import (
1010
DedupeOperationStrategy,
1111
)
12+
from futuresearch.generated.models.forecast_effort_level import ForecastEffortLevel
1213
from futuresearch.generated.models.llm_enum_public import LLMEnumPublic
1314
from futuresearch.task import EffortLevel
1415
from jsonschema import SchemaError
@@ -420,6 +421,10 @@ class ForecastInput(_SingleSourceInput):
420421
"as YYYY-MM-DD strings for timing questions like 'When will X happen?'. "
421422
"Requires output_field when 'numeric' or 'date'.",
422423
)
424+
effort_level: ForecastEffortLevel | None = Field(
425+
default=None,
426+
description="Affects accuracy and cost of forecast. Default: low.",
427+
)
423428
output_field: str | None = Field(
424429
default=None,
425430
description="Name of the numeric quantity being forecast (e.g. 'price', 'count'). "

futuresearch-mcp/src/futuresearch_mcp/tools.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ async def futuresearch_forecast(
696696
session=session,
697697
input=input_data,
698698
forecast_type=params.forecast_type,
699+
effort_level=params.effort_level,
699700
output_field=params.output_field,
700701
units=params.units,
701702
)

src/futuresearch/generated/models/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
from .dedupe_operation_strategy import DedupeOperationStrategy
2424
from .error_response import ErrorResponse
2525
from .error_response_details_type_0 import ErrorResponseDetailsType0
26+
from .forecast_effort_level import ForecastEffortLevel
2627
from .forecast_operation import ForecastOperation
27-
from .forecast_operation_forecast_type import ForecastOperationForecastType
2828
from .forecast_operation_input_type_1_item import ForecastOperationInputType1Item
2929
from .forecast_operation_input_type_2 import ForecastOperationInputType2
30+
from .forecast_type import ForecastType
3031
from .health_response import HealthResponse
3132
from .http_validation_error import HTTPValidationError
3233
from .insufficient_balance_response import InsufficientBalanceResponse
@@ -108,10 +109,11 @@
108109
"DedupeOperationStrategy",
109110
"ErrorResponse",
110111
"ErrorResponseDetailsType0",
112+
"ForecastEffortLevel",
111113
"ForecastOperation",
112-
"ForecastOperationForecastType",
113114
"ForecastOperationInputType1Item",
114115
"ForecastOperationInputType2",
116+
"ForecastType",
115117
"HealthResponse",
116118
"HTTPValidationError",
117119
"InsufficientBalanceResponse",
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import Enum
2+
3+
4+
class ForecastEffortLevel(str, Enum):
5+
HIGH = "high"
6+
LOW = "low"
7+
8+
def __str__(self) -> str:
9+
return str(self.value)

src/futuresearch/generated/models/forecast_operation.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from attrs import define as _attrs_define
88
from attrs import field as _attrs_field
99

10-
from ..models.forecast_operation_forecast_type import ForecastOperationForecastType
10+
from ..models.forecast_effort_level import ForecastEffortLevel
11+
from ..models.forecast_type import ForecastType
1112
from ..types import UNSET, Unset
1213

1314
if TYPE_CHECKING:
@@ -27,25 +28,24 @@ class ForecastOperation:
2728
of a list of JSON objects
2829
task (str): Overall context or instructions for the forecast. Each row in the input should contain the
2930
question/scenario to forecast.
30-
forecast_type (ForecastOperationForecastType): Type of forecast. 'binary': yes/no probability (0-100) for
31-
questions like 'Will X happen?'. 'numeric': percentile estimates (p10-p90) for questions like 'What will the
32-
price/value/count be?'. 'date': date percentile estimates (p10-p90) as YYYY-MM-DD strings for timing questions
33-
like 'When will X happen?'. Requires output_field when 'numeric' or 'date'.
31+
forecast_type (ForecastType):
3432
session_id (None | Unset | UUID): Session ID. If not provided, a new session is auto-created for this task.
3533
webhook_url (None | str | Unset): Optional URL to receive a POST callback when the task completes or fails.
3634
output_field (None | str | Unset): Name of the numeric quantity being forecast (e.g. 'price', 'count'). Required
3735
when forecast_type is 'numeric'. Output columns will be named {output_field}_p10 through {output_field}_p90.
3836
units (None | str | Unset): Units for the numeric forecast (e.g. 'USD per barrel', 'thousands'). Required when
3937
forecast_type is 'numeric'.
38+
effort_level (ForecastEffortLevel | Unset):
4039
"""
4140

4241
input_: ForecastOperationInputType2 | list[ForecastOperationInputType1Item] | UUID
4342
task: str
44-
forecast_type: ForecastOperationForecastType
43+
forecast_type: ForecastType
4544
session_id: None | Unset | UUID = UNSET
4645
webhook_url: None | str | Unset = UNSET
4746
output_field: None | str | Unset = UNSET
4847
units: None | str | Unset = UNSET
48+
effort_level: ForecastEffortLevel | Unset = UNSET
4949
additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)
5050

5151
def to_dict(self) -> dict[str, Any]:
@@ -91,6 +91,10 @@ def to_dict(self) -> dict[str, Any]:
9191
else:
9292
units = self.units
9393

94+
effort_level: str | Unset = UNSET
95+
if not isinstance(self.effort_level, Unset):
96+
effort_level = self.effort_level.value
97+
9498
field_dict: dict[str, Any] = {}
9599
field_dict.update(self.additional_properties)
96100
field_dict.update(
@@ -108,6 +112,8 @@ def to_dict(self) -> dict[str, Any]:
108112
field_dict["output_field"] = output_field
109113
if units is not UNSET:
110114
field_dict["units"] = units
115+
if effort_level is not UNSET:
116+
field_dict["effort_level"] = effort_level
111117

112118
return field_dict
113119

@@ -150,7 +156,7 @@ def _parse_input_(data: object) -> ForecastOperationInputType2 | list[ForecastOp
150156

151157
task = d.pop("task")
152158

153-
forecast_type = ForecastOperationForecastType(d.pop("forecast_type"))
159+
forecast_type = ForecastType(d.pop("forecast_type"))
154160

155161
def _parse_session_id(data: object) -> None | Unset | UUID:
156162
if data is None:
@@ -196,6 +202,13 @@ def _parse_units(data: object) -> None | str | Unset:
196202

197203
units = _parse_units(d.pop("units", UNSET))
198204

205+
_effort_level = d.pop("effort_level", UNSET)
206+
effort_level: ForecastEffortLevel | Unset
207+
if isinstance(_effort_level, Unset):
208+
effort_level = UNSET
209+
else:
210+
effort_level = ForecastEffortLevel(_effort_level)
211+
199212
forecast_operation = cls(
200213
input_=input_,
201214
task=task,
@@ -204,6 +217,7 @@ def _parse_units(data: object) -> None | str | Unset:
204217
webhook_url=webhook_url,
205218
output_field=output_field,
206219
units=units,
220+
effort_level=effort_level,
207221
)
208222

209223
forecast_operation.additional_properties = d

src/futuresearch/generated/models/forecast_operation_forecast_type.py renamed to src/futuresearch/generated/models/forecast_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22

33

4-
class ForecastOperationForecastType(str, Enum):
4+
class ForecastType(str, Enum):
55
BINARY = "binary"
66
DATE = "date"
77
NUMERIC = "numeric"

src/futuresearch/ops.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
DedupeOperation,
2828
DedupeOperationInputType1Item,
2929
DedupeOperationStrategy,
30+
ForecastEffortLevel,
3031
ForecastOperation,
31-
ForecastOperationForecastType,
3232
ForecastOperationInputType1Item,
33+
ForecastType,
3334
LLMEnumPublic,
3435
MergeOperation,
3536
MergeOperationLeftInputType1Item,
@@ -832,6 +833,7 @@ async def forecast(
832833
session: Session | None = None,
833834
*,
834835
forecast_type: Literal["binary", "numeric", "date"],
836+
effort_level: ForecastEffortLevel | None = None,
835837
output_field: str | None = None,
836838
units: str | None = None,
837839
) -> TableResult:
@@ -867,6 +869,7 @@ async def forecast(
867869
session: Optional session. If not provided, one will be created automatically.
868870
forecast_type: ``"binary"`` for probability forecasts, ``"numeric"`` for
869871
percentile estimates, ``"date"`` for date percentile estimates.
872+
effort_level: affects accuracy and cost of forecast. Default: low.
870873
output_field: Name of the quantity being forecast (required for numeric
871874
and date, e.g. ``"price"``, ``"launch_date"``).
872875
units: Units for numeric forecasts (e.g. ``"USD per barrel"``).
@@ -883,6 +886,7 @@ async def forecast(
883886
session=internal_session,
884887
input=input,
885888
forecast_type=forecast_type,
889+
effort_level=effort_level,
886890
output_field=output_field,
887891
units=units,
888892
)
@@ -908,7 +912,9 @@ async def forecast_async(
908912
task: str,
909913
session: Session,
910914
input: DataFrame | UUID | TableResult,
915+
*,
911916
forecast_type: Literal["binary", "numeric", "date"],
917+
effort_level: ForecastEffortLevel | None = None,
912918
output_field: str | None = None,
913919
units: str | None = None,
914920
) -> EveryrowTask[BaseModel]:
@@ -920,6 +926,7 @@ async def forecast_async(
920926
input: Input data.
921927
forecast_type: ``"binary"`` for yes/no probability, ``"numeric"`` for
922928
percentile estimates, ``"date"`` for date percentile estimates.
929+
effort_level: affects accuracy and cost of forecast. Default: low.
923930
output_field: Name of the quantity (required for numeric and date).
924931
units: Units for numeric forecasts (required for numeric).
925932
@@ -929,10 +936,11 @@ async def forecast_async(
929936
input_data = _prepare_table_input(input, ForecastOperationInputType1Item)
930937

931938
body = ForecastOperation(
932-
input_=input_data, # type: ignore
939+
input_=input_data,
933940
task=task,
934941
session_id=session.session_id,
935-
forecast_type=ForecastOperationForecastType(forecast_type),
942+
forecast_type=ForecastType(forecast_type),
943+
effort_level=effort_level if effort_level is not None else UNSET,
936944
output_field=output_field,
937945
units=units,
938946
)

0 commit comments

Comments
 (0)