diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index ca51618a..a9117d14 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -2,9 +2,12 @@ import asyncio import contextlib +import time from typing import TYPE_CHECKING, Annotated, Self import websockets.asyncio.client +import websockets.client +import websockets.exceptions from pydantic import Discriminator, TypeAdapter from typing_extensions import Unpack, override @@ -16,6 +19,7 @@ from apify.log import logger if TYPE_CHECKING: + from collections.abc import Generator from types import TracebackType from crawlee.events._event_manager import EventManagerOptions @@ -45,6 +49,17 @@ class ApifyEventManager(EventManager): with the event system. """ + _NON_RETRYABLE_CLOSE_CODES = frozenset({1002, 1003, 1007, 1008, 1010}) + """WebSocket close codes for a permanent condition, on which the connection is not re-established. + + The platform sends `1008` (policy violation) for an unknown/missing run ID or an exceeded per-run + connection limit. `1002`, `1003`, and `1007` are protocol or data errors, and `1010` a mandatory + extension failure. + """ + + _HEALTHY_CONNECTION_MIN_DURATION = 1.0 + """Seconds a connection must stay open to count as healthy, after which a drop reconnects without backoff.""" + def __init__(self, configuration: Configuration, **kwargs: Unpack[EventManagerOptions]) -> None: """Initialize a new instance. @@ -93,50 +108,121 @@ async def __aexit__( exc_value: BaseException | None, exc_traceback: TracebackType | None, ) -> None: - if self._platform_events_websocket: - await self._platform_events_websocket.close() - + # Cancel the task before closing the websocket so that the closed connection is not treated as a drop + # and followed by a reconnect attempt. if self._process_platform_messages_task and not self._process_platform_messages_task.done(): self._process_platform_messages_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._process_platform_messages_task + if self._platform_events_websocket: + await self._platform_events_websocket.close() + await super().__aexit__(exc_type, exc_value, exc_traceback) + def _process_connection_exception(self, exc: Exception) -> Exception | None: + """Decide whether a failed connection attempt to the platform websocket should be retried. + + Before the first successful connection, every error is fatal so that `__aenter__` fails fast. After that, + the default `websockets` behavior decides which errors are transient and retried with exponential backoff. + """ + if self._connected_to_platform_websocket and self._connected_to_platform_websocket.done(): + return websockets.asyncio.client.process_exception(exc) + return exc + async def _process_platform_messages(self, ws_url: str) -> None: + # The `websockets` reconnect iterator only backs off on failed connection *attempts*, not on a connection + # that opens and is then closed. Track our own backoff here so a server that keeps accepting and immediately + # closing is not hammered; it is reset after a healthy connection so a healthy drop reconnects immediately. + backoff_delays: Generator[float] | None = None + try: - async with websockets.asyncio.client.connect(ws_url) as websocket: + # Used as an async iterator, `connect` reconnects with exponential backoff on failed connection attempts. + async for websocket in websockets.asyncio.client.connect( + ws_url, process_exception=self._process_connection_exception + ): self._platform_events_websocket = websocket - if self._connected_to_platform_websocket is not None: + if self._connected_to_platform_websocket and not self._connected_to_platform_websocket.done(): self._connected_to_platform_websocket.set_result(True) - - async for message in websocket: - try: - parsed_message = event_data_adapter.validate_json(message) - - if isinstance(parsed_message, DeprecatedEvent): - continue - - if isinstance(parsed_message, UnknownEvent): - logger.info( - f'Unknown message received: event_name={parsed_message.name}, ' - f'event_data={parsed_message.data}' - ) - continue - - self.emit( - event=parsed_message.name, - event_data=parsed_message.data - if not isinstance(parsed_message.data, SystemInfoEventData) - else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), - ) - - if parsed_message.name == Event.MIGRATING: - await self._emit_persist_state_event_rec_task.stop() - self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)) - except Exception: - logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + else: + logger.info('Reconnected to the platform events websocket.') + + connection_opened_at = time.monotonic() + connection_lost = await self._consume_messages(websocket) + + if not self._should_reconnect_after_close(websocket, connection_lost=connection_lost): + break + + # Reconnect a healthy connection immediately; back off only on repeated rapid drops. + if time.monotonic() - connection_opened_at >= self._HEALTHY_CONNECTION_MIN_DURATION: + backoff_delays = None + elif backoff_delays is None: + backoff_delays = websockets.client.backoff() + else: + await asyncio.sleep(next(backoff_delays)) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): self._connected_to_platform_websocket.set_result(False) + + async def _consume_messages(self, websocket: websockets.asyncio.client.ClientConnection) -> bool: + """Handle platform messages until the connection closes; return whether it was lost vs. closed cleanly.""" + try: + async for message in websocket: + await self._handle_platform_message(message) + except websockets.exceptions.ConnectionClosed: + return True + return False + + async def _handle_platform_message(self, message: str | bytes) -> None: + """Parse a single platform message and emit the matching local event.""" + try: + parsed_message = event_data_adapter.validate_json(message) + + if isinstance(parsed_message, DeprecatedEvent): + return + + if isinstance(parsed_message, UnknownEvent): + logger.info( + f'Unknown message received: event_name={parsed_message.name}, event_data={parsed_message.data}' + ) + return + + self.emit( + event=parsed_message.name, + event_data=parsed_message.data + if not isinstance(parsed_message.data, SystemInfoEventData) + else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), + ) + + if parsed_message.name == Event.MIGRATING: + await self._emit_persist_state_event_rec_task.stop() + self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)) + except Exception: + logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + + def _should_reconnect_after_close( + self, + websocket: websockets.asyncio.client.ClientConnection, + *, + connection_lost: bool, + ) -> bool: + """Log the websocket close and report whether to reconnect (`False` on a non-retryable close code).""" + if websocket.close_code in self._NON_RETRYABLE_CLOSE_CODES: + logger.error( + f'Connection to platform events websocket was closed with a non-retryable code ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}); not reconnecting.' + ) + return False + + if connection_lost: + logger.warning( + f'Connection to platform events websocket was lost ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + else: + logger.info( + f'Connection to platform events websocket was closed ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + return True diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 21ed00bd..eb8118c3 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -4,6 +4,8 @@ import contextlib import json import logging +import socket +import types from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING, Any @@ -12,7 +14,6 @@ import pytest import websockets import websockets.asyncio.server -import websockets.exceptions from crawlee.events._types import Event @@ -23,7 +24,19 @@ from apify.events._types import SystemInfoEventData if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Callable + from collections.abc import AsyncGenerator, Awaitable, Callable + + +DUMMY_SYSTEM_INFO = { + 'memAvgBytes': 19328860.328293584, + 'memCurrentBytes': 65171456, + 'memMaxBytes': 65171456, + 'cpuAvgUsage': 2.0761105633130397, + 'cpuMaxUsage': 53.941134593993326, + 'cpuCurrentUsage': 8.45549815498155, + 'isCpuOverloaded': False, + 'createdAt': '2024-08-09T16:04:16.161Z', +} @contextlib.asynccontextmanager @@ -57,6 +70,74 @@ async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None yield connected_ws_clients, client_connected +@contextlib.asynccontextmanager +async def _restartable_ws_server( + monkeypatch: pytest.MonkeyPatch, + *, + on_connect: Callable[[websockets.asyncio.server.ServerConnection], Awaitable[None]] | None = None, +) -> AsyncGenerator[Any]: + """A local `127.0.0.1` WebSocket server that can be stopped/restarted and counts connection attempts. + + Binds to a fixed free port (reserved up front) so a restart can reuse the same address, letting a test simulate the + platform server going away and coming back. Yields a control namespace with `live_clients`, a re-armable + `client_connected` event, a cumulative `attempts()` counter, and `stop()` / `start()` coroutines. Pass `on_connect` + to take over a freshly accepted connection (e.g. immediately close it with a chosen code). + """ + # Reserve a fixed free port so a restart can re-serve on the same address. + probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + probe.bind(('127.0.0.1', 0)) + port = probe.getsockname()[1] + probe.close() + + live_clients: set[websockets.asyncio.server.ServerConnection] = set() + client_connected = asyncio.Event() + attempts = 0 + server_holder: dict[str, Any] = {'srv': None} + + async def handler(websocket: websockets.asyncio.server.ServerConnection) -> None: + nonlocal attempts + attempts += 1 + if on_connect is not None: + await on_connect(websocket) + return + live_clients.add(websocket) + client_connected.set() + try: + await websocket.wait_closed() + finally: + live_clients.discard(websocket) + + async def _serve() -> None: + server_holder['srv'] = await websockets.asyncio.server.serve(handler, host='127.0.0.1', port=port) + + async def stop() -> None: + srv = server_holder['srv'] + if srv is not None: + srv.close() + await srv.wait_closed() + server_holder['srv'] = None + # Drop any live connection so the client is forced into reconnect mode. + for websocket in list(live_clients): + await websocket.close() + + async def start() -> None: + await asyncio.sleep(0.3) # Give the OS a moment to release the port before re-serving. + await _serve() + + monkeypatch.setenv(ActorEnvVars.EVENTS_WEBSOCKET_URL, f'ws://127.0.0.1:{port}') + await _serve() + try: + yield types.SimpleNamespace( + live_clients=live_clients, + client_connected=client_connected, + attempts=lambda: attempts, + stop=stop, + start=start, + ) + finally: + await stop() + + async def test_lifecycle_local(caplog: pytest.LogCaptureFixture) -> None: caplog.set_level(logging.DEBUG, logger='apify') @@ -194,17 +275,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None: websockets.broadcast(connected_ws_clients, json.dumps(message)) - dummy_system_info = { - 'memAvgBytes': 19328860.328293584, - 'memCurrentBytes': 65171456, - 'memMaxBytes': 65171456, - 'cpuAvgUsage': 2.0761105633130397, - 'cpuMaxUsage': 53.941134593993326, - 'cpuCurrentUsage': 8.45549815498155, - 'isCpuOverloaded': False, - 'createdAt': '2024-08-09T16:04:16.161Z', - } - SystemInfoEventData.model_validate(dummy_system_info) + SystemInfoEventData.model_validate(DUMMY_SYSTEM_INFO) async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: await client_connected.wait() @@ -216,7 +287,7 @@ def listener(data: Any) -> None: event_manager.on(event=Event.SYSTEM_INFO, listener=listener) # Test sending event with data - await send_platform_event(Event.SYSTEM_INFO, dummy_system_info) + await send_platform_event(Event.SYSTEM_INFO, DUMMY_SYSTEM_INFO) await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) assert len(event_calls) == 1 assert event_calls[0] is not None @@ -325,38 +396,191 @@ def migrating_listener(data: Any) -> None: assert len(migration_persist_events) >= 1 -async def test_websocket_mid_stream_disconnect_does_not_raise_invalid_state_error( - monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +@pytest.mark.parametrize( + ('close_code', 'expected_log'), + [ + pytest.param(1000, 'Connection to platform events websocket was closed (code=1000', id='graceful_close'), + pytest.param(1011, 'Connection to platform events websocket was lost (code=1011', id='abnormal_close'), + ], +) +async def test_websocket_reconnects_after_connection_drop( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, close_code: int, expected_log: str ) -> None: - """Regression: a mid-stream websocket disconnect after a successful connect must not raise InvalidStateError. + """Test that the event manager logs a websocket drop, reconnects, and keeps receiving platform events. - The `_connected_to_platform_websocket` future is resolved to `True` on successful connect. If the websocket - later drops, the outer `except` in `_process_platform_messages` must not call `set_result(False)` on the - already-resolved future. + Also a regression test for the resolved `_connected_to_platform_websocket` future: a mid-stream disconnect + must not kill the message-processing task with `InvalidStateError`. """ + caplog.set_level(logging.INFO, logger='apify') async with ( _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), ApifyEventManager(Configuration.get_global_configuration()) as event_manager, ): await client_connected.wait() + assert len(connected_ws_clients) == 1 + + event_calls: list[Any] = [] + event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append) - # Force an abnormal close from the server so the client's `async for` raises ConnectionClosedError. + # Drop the connection from the server side and wait for the client to reconnect. + client_connected.clear() for ws in list(connected_ws_clients): - await ws.close(code=1011, reason='Simulated server error') + await ws.close(code=close_code, reason='Simulated connection drop') + await asyncio.wait_for(client_connected.wait(), timeout=10) + # Poll because the old server-side handler may not have deregistered its connection yet. + await poll_until_condition(lambda: len(connected_ws_clients) == 1, poll_interval=0.05) + assert len(connected_ws_clients) == 1 + + # The message-processing task must have survived the drop. + task = event_manager._process_platform_messages_task + assert task is not None + assert not task.done() + + # Events sent over the new connection must still be emitted. + websockets.broadcast(connected_ws_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) + assert len(event_calls) == 1 + + # Both the drop and the successful reconnect must be logged. + assert expected_log in caplog.text + assert 'Reconnected to the platform events websocket.' in caplog.text + + +async def test_non_retryable_close_stops_reconnecting( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that a non-retryable close code (1008) stops reconnection instead of looping forever.""" + caplog.set_level(logging.ERROR, logger='apify') + + async def close_with_policy_violation(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.close(code=1008, reason='policy violation') + async with ( + _restartable_ws_server(monkeypatch, on_connect=close_with_policy_violation) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): task = event_manager._process_platform_messages_task assert task is not None - await asyncio.wait_for(asyncio.shield(task), timeout=2.0) - exc = task.exception() - assert not isinstance(exc, asyncio.InvalidStateError), f'Task raised InvalidStateError: {exc}' + # After a non-retryable close the processing task must give up rather than reconnect forever. + await poll_until_condition(task.done, poll_interval=0.05) + assert task.done() + assert server.attempts() <= 5, f'reconnected after a non-retryable close: {server.attempts()} attempts' - # Confirm the test actually exercised the disconnect path — the outer `except` in - # `_process_platform_messages` should have logged a `ConnectionClosedError`. - logged_exc_types = [ - record.exc_info[0] for record in caplog.records if record.exc_info and record.exc_info[0] is not None - ] - assert any(issubclass(exc_type, websockets.exceptions.ConnectionClosedError) for exc_type in logged_exc_types) + assert 'non-retryable code' in caplog.text + + +async def test_rapid_retryable_close_backs_off( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that repeated retryable closes are retried with backoff instead of a tight reconnect loop.""" + caplog.set_level(logging.WARNING, logger='apify') + + async def close_with_internal_error(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.close(code=1011, reason='internal error') + + async with ( + _restartable_ws_server(monkeypatch, on_connect=close_with_internal_error) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + task = event_manager._process_platform_messages_task + assert task is not None + + # Without backoff a tight loop would make thousands of attempts in this window; backoff keeps it tiny. + await asyncio.sleep(2) + assert not task.done() + attempts = server.attempts() + + assert 0 < attempts <= 15, f'client busy-looped on a retryable close: {attempts} attempts in 2s' + assert 'was lost (code=1011' in caplog.text + + +async def test_rapid_retryable_close_after_event_backs_off(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that a server that delivers an event before each retryable close is still retried with backoff.""" + + async def send_event_then_close(websocket: websockets.asyncio.server.ServerConnection) -> None: + await websocket.send(json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await websocket.close(code=1011, reason='internal error') + + async with ( + _restartable_ws_server(monkeypatch, on_connect=send_event_then_close) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + task = event_manager._process_platform_messages_task + assert task is not None + + # A short-lived connection must back off even though it delivered an event, or it would busy-loop. + await asyncio.sleep(2) + assert not task.done() + attempts = server.attempts() + + assert 0 < attempts <= 15, f'client busy-looped after a message-bearing close: {attempts} attempts in 2s' + + +async def test_reconnects_after_server_becomes_unreachable( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + """Test that the client survives a server outage, keeps retrying, and resumes events once the server returns.""" + caplog.set_level(logging.INFO, logger='apify') + + async with ( + _restartable_ws_server(monkeypatch) as server, + ApifyEventManager(Configuration.get_global_configuration()) as event_manager, + ): + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + assert len(server.live_clients) == 1 + + event_calls: list[Any] = [] + event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append) + + # Take the server down and drop the live connection: every reconnect attempt now hits connection-refused. + server.client_connected.clear() + await server.stop() + task = event_manager._process_platform_messages_task + assert task is not None + + # During the outage the task must keep retrying instead of crashing or exiting. + await asyncio.sleep(1) + assert not task.done() + + # Bring the server back on the same port; the client must reconnect within a bounded time. + await server.start() + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + await poll_until_condition(lambda: len(server.live_clients) == 1, poll_interval=0.05) + assert len(server.live_clients) == 1 + + # Events sent over the recovered connection must still be delivered. + websockets.broadcast(server.live_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) + assert len(event_calls) == 1 + + assert 'Reconnected to the platform events websocket.' in caplog.text + + +async def test_shutdown_during_reconnect_backoff_is_clean(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that exiting the event manager while it is mid-reconnect (server down) shuts down cleanly.""" + async with _restartable_ws_server(monkeypatch) as server: + event_manager = ApifyEventManager(Configuration.get_global_configuration()) + async with event_manager: + await asyncio.wait_for(server.client_connected.wait(), timeout=10) + assert len(server.live_clients) == 1 + + # Force the client into reconnect/backoff: server down, live connection dropped. + server.client_connected.clear() + await server.stop() + task = event_manager._process_platform_messages_task + assert task is not None + await asyncio.sleep(0.5) + assert not task.done() + # __aexit__ runs here, while the client is between reconnect attempts. + + # The processing task must be finished and cancelled, not crashed with a stray error. + assert task.done() + assert task.cancelled() or task.exception() is None + assert event_manager.active is False + # The parent recurring persist-state task must be stopped too, mirroring the failed-connect lifecycle test. + persist_state_task = event_manager._emit_persist_state_event_rec_task.task + assert persist_state_task is None or persist_state_task.done() async def test_malformed_message_logs_exception(