-
Notifications
You must be signed in to change notification settings - Fork 3.5k
[WIP] implement pod exec v5 #2486
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,9 @@ | |
| ERROR_CHANNEL = 3 | ||
| RESIZE_CHANNEL = 4 | ||
|
|
||
| V4_CHANNEL_PROTOCOL = "v4.channel.k8s.io" | ||
| V5_CHANNEL_PROTOCOL = "v5.channel.k8s.io" | ||
|
|
||
| class _IgnoredIO: | ||
| def write(self, _x): | ||
| pass | ||
|
|
@@ -59,13 +62,16 @@ def __init__(self, configuration, url, headers, capture_all, binary=False): | |
| """ | ||
| self._connected = False | ||
| self._channels = {} | ||
| self._closed_channels = set() | ||
| self.subprotocol = None | ||
| self.binary = binary | ||
| self.newline = '\n' if not self.binary else b'\n' | ||
| if capture_all: | ||
| self._all = StringIO() if not self.binary else BytesIO() | ||
| else: | ||
| self._all = _IgnoredIO() | ||
| self.sock = create_websocket(configuration, url, headers) | ||
| self.subprotocol = getattr(self.sock, 'subprotocol', None) | ||
| self._connected = True | ||
| self._returncode = None | ||
|
|
||
|
|
@@ -93,6 +99,7 @@ def readline_channel(self, channel, timeout=None): | |
| timeout = float("inf") | ||
| start = time.time() | ||
| while self.is_open() and time.time() - start < timeout: | ||
| # Always try to drain the channel first | ||
| if channel in self._channels: | ||
| data = self._channels[channel] | ||
| if self.newline in data: | ||
|
|
@@ -104,6 +111,14 @@ def readline_channel(self, channel, timeout=None): | |
| else: | ||
| del self._channels[channel] | ||
| return ret | ||
|
|
||
| if channel in self._closed_channels: | ||
| if channel in self._channels: | ||
| ret = self._channels[channel] | ||
| del self._channels[channel] | ||
| return ret | ||
| return b"" if self.binary else "" | ||
|
|
||
| self.update(timeout=(timeout - time.time() + start)) | ||
|
|
||
| def write_channel(self, channel, data): | ||
|
|
@@ -119,6 +134,14 @@ def write_channel(self, channel, data): | |
| payload = channel_prefix + data | ||
| self.sock.send(payload, opcode=opcode) | ||
|
|
||
| def close_channel(self, channel): | ||
| """Close a channel (v5 protocol only).""" | ||
| if self.subprotocol != V5_CHANNEL_PROTOCOL: | ||
| return | ||
| data = bytes([255, channel]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we define a constant (e.g. CLOSE_CHANNEL = 255 or V5_HALF_CLOSE = 255) for this magic number, and reference the constant? |
||
| self.sock.send(data, opcode=ABNF.OPCODE_BINARY) | ||
| self._closed_channels.add(channel) | ||
|
|
||
| def peek_stdout(self, timeout=0): | ||
| """Same as peek_channel with channel=1.""" | ||
| return self.peek_channel(STDOUT_CHANNEL, timeout=timeout) | ||
|
|
@@ -200,13 +223,24 @@ def update(self, timeout=0): | |
| return | ||
| elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT: | ||
| data = frame.data | ||
| if six.PY3 and not self.binary: | ||
| data = data.decode("utf-8", "replace") | ||
| if len(data) > 1: | ||
| if len(data) > 0: | ||
| # Parse channel from raw bytes to support v5 CLOSE signal AND avoid charset issues | ||
| channel = data[0] | ||
| if six.PY3 and not self.binary: | ||
| channel = ord(channel) | ||
| # In Py3, iterating bytes gives int, but indexing bytes gives int. | ||
| # websocket-client frame.data might be bytes. | ||
|
|
||
| if channel == 255 and self.subprotocol == V5_CHANNEL_PROTOCOL: # v5 CLOSE | ||
| if len(data) > 1: | ||
| # data[1] is already int in Py3 bytes | ||
| close_chan = data[1] | ||
| self._closed_channels.add(close_chan) | ||
| return | ||
|
|
||
| data = data[1:] | ||
| # Decode data if expected text | ||
| if not self.binary: | ||
| data = data.decode("utf-8", "replace") | ||
|
|
||
| if data: | ||
| if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]: | ||
| # keeping all messages in the order they received | ||
|
|
@@ -469,7 +503,7 @@ def create_websocket(configuration, url, headers=None): | |
| header.append("sec-websocket-protocol: %s" % | ||
| headers['sec-websocket-protocol']) | ||
| else: | ||
| header.append("sec-websocket-protocol: v4.channel.k8s.io") | ||
| header.append("sec-websocket-protocol: %s,%s" % (V5_CHANNEL_PROTOCOL, V4_CHANNEL_PROTOCOL)) | ||
|
|
||
| if url.startswith('wss://') and configuration.verify_ssl: | ||
| ssl_opts = { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,15 +13,18 @@ | |
| # limitations under the License. | ||
|
|
||
| import unittest | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from .ws_client import get_websocket_url | ||
| from . import ws_client as ws_client_module | ||
| from .ws_client import get_websocket_url, WSClient, V5_CHANNEL_PROTOCOL, V4_CHANNEL_PROTOCOL | ||
| from .ws_client import websocket_proxycare | ||
| from kubernetes.client.configuration import Configuration | ||
| import os | ||
| import socket | ||
| import threading | ||
| import pytest | ||
| from kubernetes import stream, client, config | ||
| import websocket | ||
|
|
||
| try: | ||
| import urllib3 | ||
|
|
@@ -123,6 +126,117 @@ def test_websocket_proxycare(self): | |
| assert dictval(connect_opts, 'http_proxy_auth') == expect_auth | ||
| assert dictval(connect_opts, 'http_no_proxy') == expect_noproxy | ||
|
|
||
|
|
||
| class WSClientProtocolTest(unittest.TestCase): | ||
| """Tests for WSClient V5 protocol handling""" | ||
|
|
||
| def setUp(self): | ||
| # Mock configuration to avoid real connections in WSClient.__init__ | ||
| self.config_mock = MagicMock() | ||
| self.config_mock.assert_hostname = False | ||
| self.config_mock.api_key = {} | ||
| self.config_mock.proxy = None | ||
| self.config_mock.ssl_ca_cert = None | ||
| self.config_mock.cert_file = None | ||
| self.config_mock.key_file = None | ||
| self.config_mock.verify_ssl = True | ||
|
|
||
| def test_create_websocket_header(self): | ||
| """Verify sec-websocket-protocol header requests v5 first""" | ||
| # Patch WebSocket class in the module | ||
| with patch.object(ws_client_module, 'WebSocket', autospec=True) as mock_ws_cls: | ||
| mock_ws = mock_ws_cls.return_value | ||
|
|
||
| WSClient(self.config_mock, "ws://test", headers=None, capture_all=True) | ||
|
|
||
| mock_ws.connect.assert_called_once() | ||
| call_args = mock_ws.connect.call_args | ||
| # connect(url, **options) | ||
| # check kwargs for 'header' | ||
| kwargs = call_args[1] | ||
| self.assertIn('header', kwargs) | ||
| expected_header = f"sec-websocket-protocol: {V5_CHANNEL_PROTOCOL},{V4_CHANNEL_PROTOCOL}" | ||
| self.assertIn(expected_header, kwargs['header']) | ||
|
|
||
| def test_close_channel_v5(self): | ||
| """Verify close_channel sends correct frame when v5 is negotiated""" | ||
| with patch.object(ws_client_module, 'create_websocket') as mock_create: | ||
| mock_ws = MagicMock() | ||
| mock_ws.subprotocol = V5_CHANNEL_PROTOCOL | ||
| mock_ws.connected = True | ||
| mock_create.return_value = mock_ws | ||
|
|
||
| client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True) | ||
| client.close_channel(0) | ||
|
|
||
| mock_ws.send.assert_called_with(b'\xff\x00', opcode=websocket.ABNF.OPCODE_BINARY) | ||
|
|
||
| def test_close_channel_v4(self): | ||
| """Verify close_channel does nothing when v4 is negotiated""" | ||
| with patch.object(ws_client_module, 'create_websocket') as mock_create: | ||
| mock_ws = MagicMock() | ||
| mock_ws.subprotocol = V4_CHANNEL_PROTOCOL | ||
| mock_ws.connected = True | ||
| mock_create.return_value = mock_ws | ||
|
|
||
| client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True) | ||
| client.close_channel(0) | ||
|
|
||
| mock_ws.send.assert_not_called() | ||
|
|
||
| def test_update_receives_close_v5(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add an additional unit test to verify how While the current unit tests cover the parsing of the close signal itself, the new logic you added inside |
||
| """Verify update processes close signal when v5 is negotiated""" | ||
| with patch.object(ws_client_module, 'create_websocket') as mock_create, \ | ||
| patch('select.select') as mock_select: | ||
|
|
||
| mock_ws = MagicMock() | ||
| mock_ws.subprotocol = V5_CHANNEL_PROTOCOL | ||
| mock_ws.connected = True | ||
| mock_ws.sock.fileno.return_value = 10 | ||
|
|
||
| # Setup frame with close signal for channel 0 | ||
| frame = MagicMock() | ||
| frame.data = b'\xff\x00' | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this be better as constants, such as |
||
| mock_ws.recv_data_frame.return_value = (websocket.ABNF.OPCODE_BINARY, frame) | ||
|
|
||
| mock_create.return_value = mock_ws | ||
| # Make select return ready | ||
| mock_select.return_value = ([mock_ws.sock], [], []) | ||
|
|
||
| client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True) | ||
| client.update(timeout=0) | ||
|
|
||
| self.assertIn(0, client._closed_channels) | ||
|
|
||
| def test_update_ignores_close_signal_v4(self): | ||
| """Verify update treats 0xFF as regular data (or ignores signal interpretation) when v4""" | ||
| with patch.object(ws_client_module, 'create_websocket') as mock_create, \ | ||
| patch('select.select') as mock_select: | ||
|
|
||
| mock_ws = MagicMock() | ||
| mock_ws.subprotocol = V4_CHANNEL_PROTOCOL | ||
| mock_ws.connected = True | ||
| mock_ws.sock.fileno.return_value = 10 | ||
|
|
||
| # Setup frame that looks like close signal but should be treated as data | ||
| frame = MagicMock() | ||
| frame.data = b'\xff\x00' | ||
| mock_ws.recv_data_frame.return_value = (websocket.ABNF.OPCODE_BINARY, frame) | ||
|
|
||
| mock_create.return_value = mock_ws | ||
| mock_select.return_value = ([mock_ws.sock], [], []) | ||
|
|
||
| client = WSClient(self.config_mock, "ws://test", headers=None, capture_all=True, binary=True) # binary=True to avoid decode errors | ||
| client.update(timeout=0) | ||
|
|
||
| # Should NOT be in closed channels | ||
| self.assertNotIn(0, client._closed_channels) | ||
| # Should be in data channels (channel 255 with data \x00) | ||
| # Code: channel = data[0] (255), data = data[1:] (\x00) | ||
| # if channel (255) not in _channels... | ||
| self.assertIn(255, client._channels) | ||
| self.assertEqual(client._channels[255], b'\x00') | ||
|
|
||
| @pytest.fixture(scope="module") | ||
| def dummy_proxy(): | ||
| #Dummy Proxy | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -201,6 +201,63 @@ def test_pod_apis(self): | |
| resp = api.delete_namespaced_pod(name=name, body={}, | ||
| namespace='default') | ||
|
|
||
| def test_pod_exec_close_channel(self): | ||
| """Test sending CLOSE signal for a channel (v5 protocol).""" | ||
| client = api_client.ApiClient(configuration=self.config) | ||
| api = core_v1_api.CoreV1Api(client) | ||
|
|
||
| name = 'busybox-test-' + short_uuid() | ||
| pod_manifest = manifest_with_command( | ||
| name, "while true;do date;sleep 5; done") | ||
|
|
||
| resp = api.create_namespaced_pod(body=pod_manifest, namespace='default') | ||
| self.assertEqual(name, resp.metadata.name) | ||
|
|
||
| # Wait for pod to be running | ||
| timeout = time.time() + 60 | ||
| while True: | ||
| resp = api.read_namespaced_pod(name=name, namespace='default') | ||
| if resp.status.phase == 'Running': | ||
| break | ||
| if time.time() > timeout: | ||
| self.fail("Timeout waiting for pod to be running") | ||
| time.sleep(1) | ||
|
|
||
| # Use cat to echo stdin to stdout. | ||
| # When stdin is closed, cat should exit, terminating the command. | ||
| resp = stream(api.connect_post_namespaced_pod_exec, name, 'default', | ||
| command=['/bin/sh', '-c', 'cat'], | ||
| stderr=True, stdin=True, | ||
| stdout=True, tty=False, | ||
| _preload_content=False) | ||
|
|
||
| if resp.subprotocol != "v5.channel.k8s.io": | ||
| resp.close() | ||
| api.delete_namespaced_pod(name=name, body={}, namespace='default') | ||
| self.skipTest("Skipping test: v5.channel.k8s.io subprotocol not negotiated") | ||
|
|
||
| try: | ||
| resp.write_stdin("test-close\n") | ||
| line = resp.readline_stdout(timeout=5) | ||
| self.assertEqual("test-close", line) | ||
|
|
||
| # Close stdin (channel 0) | ||
| # This should send EOF to cat, causing it to exit. | ||
| resp.close_channel(0) | ||
|
|
||
| # Wait for process to exit | ||
| resp.update(timeout=5) | ||
| start = time.time() | ||
| while resp.is_open() and time.time() - start < 10: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we simplify this polling loop by using resp.run_forever(timeout=15) instead of manually checking resp.is_open()? This would align it closely with how the other exec tests in this file wait for streams to close. |
||
| resp.update(timeout=1) | ||
|
|
||
| self.assertFalse(resp.is_open(), "Connection should close after cat exits") | ||
| self.assertEqual(resp.returncode, 0) | ||
| finally: | ||
| if resp.is_open(): | ||
| resp.close() | ||
| api.delete_namespaced_pod(name=name, body={}, namespace='default') | ||
|
|
||
| def test_exit_code(self): | ||
| client = api_client.ApiClient(configuration=self.config) | ||
| api = core_v1_api.CoreV1Api(client) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this same short-circuiting code in
peek_channelandread_channel?