Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/project_x_py/client/trading.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ async def main():

import datetime
import logging
from dataclasses import fields
from datetime import timedelta
from typing import Any

Expand All @@ -77,6 +78,14 @@ async def main():

logger = logging.getLogger(__name__)

_POSITION_FIELDS = frozenset(field.name for field in fields(Position))


def _position_from_response(data: dict[str, Any]) -> Position:
return Position(
**{field: data[field] for field in _POSITION_FIELDS if field in data}
)


class TradingMixin:
"""Mixin class providing trading functionality."""
Expand Down Expand Up @@ -182,7 +191,7 @@ async def search_open_positions(
else:
return []

return [Position(**pos) for pos in positions_data]
return [_position_from_response(pos) for pos in positions_data]

async def search_trades(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/project_x_py/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ class Position:
0=UNDEFINED, 1=LONG, 2=SHORT
size (int): Position size (number of contracts, always positive)
averagePrice (float): Average entry price of the position
contractDisplayName (Optional[str]): Human-readable contract display name

Note:
This model contains only the fields returned by ProjectX API.
Expand All @@ -385,11 +386,12 @@ class Position:
type: int
size: int
averagePrice: float
contractDisplayName: str | None = None

# Allow dict-like access for compatibility in tests/utilities
def __getitem__(self, key: str) -> Union[int, str, float]:
def __getitem__(self, key: str) -> Union[int, str, float, None]:
value = getattr(self, key)
if isinstance(value, int | str | float):
if value is None or isinstance(value, int | str | float):
return value
else:
raise TypeError(
Expand Down
1 change: 1 addition & 0 deletions src/project_x_py/types/api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class PositionResponse(TypedDict):
type: int # 0=UNDEFINED, 1=LONG, 2=SHORT
size: int
averagePrice: float
contractDisplayName: NotRequired[str]


class TradeResponse(TypedDict):
Expand Down
39 changes: 39 additions & 0 deletions tests/client/test_trading_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,45 @@ async def test_search_open_positions_success(self, trading_client):
assert positions[1].type == 2 # SHORT
trading_client._ensure_authenticated.assert_called_once()

@pytest.mark.asyncio
async def test_search_open_positions_preserves_display_name_and_ignores_unknown_fields(
self, trading_client
):
"""Test position search keeps known display fields and ignores unknown ones."""
trading_client.account_info = Account(
id=12345,
name="Test Account",
balance=10000.0,
canTrade=True,
isVisible=True,
simulated=False,
)

mock_response = {
"success": True,
"positions": [
{
"id": "pos1",
"accountId": 12345,
"contractId": "CON.F.US.MNQ.Z25",
"contractDisplayName": "MNQZ25",
"unknownGatewayField": "ignored",
"creationTimestamp": datetime.datetime.now(pytz.UTC).isoformat(),
"size": 2,
"averagePrice": 21342.25,
"type": 1,
},
],
}
trading_client._make_request.return_value = mock_response

positions = await trading_client.search_open_positions()

assert len(positions) == 1
assert positions[0].contractId == "CON.F.US.MNQ.Z25"
assert positions[0].contractDisplayName == "MNQZ25"
assert positions[0].size == 2

@pytest.mark.asyncio
async def test_search_open_positions_with_account_id(self, trading_client):
"""Test position search with specific account ID."""
Expand Down
3 changes: 3 additions & 0 deletions tests/types/test_api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_position_response_structure(self):
assert hints["size"] is int
assert "averagePrice" in hints
assert hints["averagePrice"] is float
assert "contractDisplayName" in hints

def test_trade_response_structure(self):
"""Test TradeResponse has correct fields."""
Expand Down Expand Up @@ -267,6 +268,7 @@ def test_real_world_response_creation(self):
"id": 67890,
"accountId": 12345,
"contractId": "CON.F.US.MNQ.U25",
"contractDisplayName": "MNQU25",
"creationTimestamp": "2024-01-01T10:00:00Z",
"type": 1, # LONG
"size": 5,
Expand All @@ -275,6 +277,7 @@ def test_real_world_response_creation(self):

assert position["type"] == 1
assert position["size"] == 5
assert position["contractDisplayName"] == "MNQU25"

def test_market_data_responses(self):
"""Test market data response structures."""
Expand Down
6 changes: 6 additions & 0 deletions tests/types/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ def test_basic_properties_and_indexing(self):
assert p.direction == "LONG"
assert p["averagePrice"] == pytest.approx(2050.0)
assert p.symbol == "MGC"
assert p.contractDisplayName is None

def test_contract_display_name(self):
p = self.make_position(contractDisplayName="MGCM25")
assert p.contractDisplayName == "MGCM25"
assert p["contractDisplayName"] == "MGCM25"

def test_short_position_helpers(self):
p = self.make_position(type=2, size=3)
Expand Down