# Source code for coinbase.websocket.websocket_base

import asyncio
import json
import logging
import os
import ssl
import threading
import time
from multiprocessing import AuthenticationError
from typing import IO, Callable, List, Optional, Union

import backoff
import websockets

from coinbase import jwt_generator
from coinbase.api_base import APIBase, get_logger
from coinbase.constants import (
    API_ENV_KEY,
    API_SECRET_ENV_KEY,
    SUBSCRIBE_MESSAGE_TYPE,
    UNSUBSCRIBE_MESSAGE_TYPE,
    USER_AGENT,
    WS_AUTH_CHANNELS,
    WS_BASE_URL,
    WS_RETRY_BASE,
    WS_RETRY_FACTOR,
    WS_RETRY_MAX,
)

# Module-level logger for the websocket client, obtained from
# coinbase.api_base.get_logger. Because it is module-level it is shared by
# every client instance; WSBase.__init__ raises its level to DEBUG when
# verbose=True.
logger = get_logger("coinbase.WSClient")


class WSClientException(Exception):
    """
    **WSClientException**
    ________________________________________

    -----------------------------------------

    Exception raised for errors in the WebSocket client.
    """
class WSClientConnectionClosedException(Exception):
    """
    **WSClientConnectionClosedException**
    ________________________________________

    ----------------------------------------

    Exception raised for unexpected closure in the WebSocket client.
    """
class WSBase(APIBase):
    """
    :meta private:

    Shared websocket client implementation: connection lifecycle,
    (un)subscription bookkeeping, reconnect-with-backoff and propagation of
    errors raised on the background message handler. The blocking methods run
    their ``*_async`` counterparts on a private event loop that is executed by
    a daemon thread.
    """

    def __init__(
        self,
        api_key: Optional[str] = os.getenv(API_ENV_KEY),
        api_secret: Optional[str] = os.getenv(API_SECRET_ENV_KEY),
        key_file: Optional[Union[IO, str]] = None,
        base_url=WS_BASE_URL,
        timeout: Optional[int] = None,
        max_size: Optional[int] = 10 * 1024 * 1024,
        on_message: Optional[Callable[[str], None]] = None,
        on_open: Optional[Callable[[], None]] = None,
        on_close: Optional[Callable[[], None]] = None,
        retry: Optional[bool] = True,
        verbose: Optional[bool] = False,
    ):
        """
        :meta private:

        Configure the client; no connection is opened here.

        - **api_key** / **api_secret** / **key_file** - credentials forwarded
          to ``APIBase``. NOTE: the env-var defaults are evaluated once, when
          this module is imported.
        - **base_url** - websocket URL to connect to.
        - **timeout** - passed to ``websockets.connect`` as ``open_timeout``.
        - **max_size** - maximum incoming message size passed to
          ``websockets.connect``.
        - **on_message** - required callback invoked with each received message.
        - **on_open** / **on_close** - optional callbacks run after the socket
          opens / closes.
        - **retry** - reconnect and resubscribe when the connection closes
          with an error.
        - **verbose** - set the module logger to DEBUG.

        Raises ``WSClientException`` if ``on_message`` is not provided.
        """
        super().__init__(
            api_key=api_key,
            api_secret=api_secret,
            key_file=key_file,
            base_url=base_url,
            timeout=timeout,
            verbose=verbose,
        )

        if not on_message:
            raise WSClientException("on_message callback is required.")

        if verbose:
            logger.setLevel(logging.DEBUG)

        # User callbacks and connection limits.
        self.max_size = max_size
        self.on_message = on_message
        self.on_open = on_open
        self.on_close = on_close

        # Connection state: socket, private event loop, its thread, and the
        # message-handler task.
        self.websocket = None
        self.loop = None
        self.thread = None
        self._task = None

        # Reconnect (backoff) configuration and the current attempt counter.
        self.retry = retry
        self._retry_max_tries = WS_RETRY_MAX
        self._retry_base = WS_RETRY_BASE
        self._retry_factor = WS_RETRY_FACTOR
        self._retry_count = 0

        # channel -> set of product ids currently subscribed.
        self.subscriptions = {}
        # Error captured by the message handler, surfaced by
        # raise_background_exception().
        self._background_exception = None
        # True while the message handler is reconnecting.
        self._retrying = False

    def open(self) -> None:
        """
        **Open Websocket**
        __________________

        ------------------------

        Open the websocket client connection.

        Blocks until the connection attempt finishes; raises
        ``WSClientException`` if it fails.
        """
        # Lazily (re)create the private event loop and run it on a daemon
        # thread so blocking callers can schedule coroutines onto it.
        if not self.loop or self.loop.is_closed():
            self.loop = asyncio.new_event_loop()  # Create a new event loop
            self.thread = threading.Thread(target=self.loop.run_forever)
            self.thread.daemon = True
            self.thread.start()

        self._run_coroutine_threadsafe(self.open_async())

    async def open_async(self) -> None:
        """
        **Open Websocket Async**
        ________________________

        ------------------------

        Open the websocket client connection asynchronously.

        Raises ``WSClientException`` if the socket is already open, the
        attempt times out, or the connection cannot be established.
        """
        self._ensure_websocket_not_open()

        headers = self._set_headers()

        logger.debug("Connecting to %s", self.base_url)
        try:
            # Use a verifying TLS context (system CA bundle + hostname check).
            # A bare ``ssl.SSLContext()`` performs no certificate verification
            # and is deprecated without a protocol argument since Python 3.10.
            ssl_context = (
                ssl.create_default_context()
                if self.base_url.startswith("wss://")
                else None
            )
            self.websocket = await websockets.connect(
                self.base_url,
                open_timeout=self.timeout,
                max_size=self.max_size,
                user_agent_header=USER_AGENT,
                extra_headers=headers,
                ssl=ssl_context,
            )
            logger.debug("Successfully connected to %s", self.base_url)

            if self.on_open:
                self.on_open()

            # Start the message handler coroutine after establishing the
            # connection. A reconnect is driven from inside the running
            # handler, so no second handler task is started while retrying.
            if not self._retrying:
                self._task = asyncio.create_task(self._message_handler())

        except asyncio.TimeoutError as toe:
            self.websocket = None
            logger.error("Connection attempt timed out: %s", toe)
            raise WSClientException("Connection attempt timed out") from toe
        except (websockets.exceptions.WebSocketException, OSError) as wse:
            self.websocket = None
            logger.error("Failed to establish WebSocket connection: %s", wse)
            raise WSClientException(
                "Failed to establish WebSocket connection"
            ) from wse

    def close(self) -> None:
        """
        **Close Websocket**
        ___________________

        ------------------------

        Close the websocket client connection.

        Also stops and closes the private event loop and joins its thread.
        Raises ``WSClientException`` if the event loop is not running.
        """
        if self.loop and not self.loop.is_closed():
            try:
                # Schedule the asynchronous close and wait for it.
                self._run_coroutine_threadsafe(self.close_async())
            finally:
                # Tear the loop down even if closing the socket failed, so the
                # background thread is not leaked.
                self.loop.call_soon_threadsafe(self.loop.stop)
                self.thread.join()
                self.loop.close()
        else:
            raise WSClientException("Event loop is not running.")

    async def close_async(self) -> None:
        """
        **Close Websocket Async**
        _________________________

        ------------------------

        Close the websocket client connection asynchronously.

        Clears the subscription map and invokes ``on_close`` on success.
        """
        self._ensure_websocket_open()

        logger.debug("Closing connection to %s", self.base_url)
        try:
            await self.websocket.close()
            self.websocket = None
            self.subscriptions = {}

            logger.debug("Connection closed to %s", self.base_url)

            if self.on_close:
                self.on_close()
        except (websockets.exceptions.WebSocketException, OSError) as wse:
            logger.error("Failed to close WebSocket connection: %s", wse)
            raise WSClientException("Failed to close WebSocket connection.") from wse

    def subscribe(self, product_ids: List[str], channels: List[str]) -> None:
        """
        **Subscribe**
        _____________

        ------------------------

        Subscribe to a list of channels for a list of product ids.

        - **product_ids** - product ids to subscribe to
        - **channels** - channels to subscribe to
        """
        if self.loop and not self.loop.is_closed():
            self._run_coroutine_threadsafe(self.subscribe_async(product_ids, channels))
        else:
            raise WSClientException("Websocket Client is not open.")

    async def subscribe_async(
        self, product_ids: List[str], channels: List[str]
    ) -> None:
        """
        **Subscribe Async**
        ___________________

        ------------------------

        Async subscribe to a list of channels for a list of product ids.

        - **product_ids** - product ids to subscribe to
        - **channels** - channels to subscribe to
        """
        self._ensure_websocket_open()
        for channel in channels:
            try:
                if not self.is_authenticated and channel in WS_AUTH_CHANNELS:
                    raise AuthenticationError(
                        "Unauthenticated request to private channel."
                    )

                is_public = channel not in WS_AUTH_CHANNELS
                message = self._build_subscription_message(
                    product_ids, channel, SUBSCRIBE_MESSAGE_TYPE, is_public
                )
                json_message = json.dumps(message)

                logger.debug(
                    "Subscribing to channel %s for product IDs: %s",
                    channel,
                    product_ids,
                )

                await self.websocket.send(json_message)

                logger.debug("Successfully sent subscription message.")

                # add to subscriptions map
                self.subscriptions.setdefault(channel, set()).update(product_ids)
            # OSError is caught here too, matching unsubscribe_async.
            except (websockets.exceptions.WebSocketException, OSError) as wse:
                logger.error(
                    "Failed to subscribe to %s channel for product IDs %s: %s",
                    channel,
                    product_ids,
                    wse,
                )
                raise WSClientException(
                    f"Failed to subscribe to {channel} channel for product ids {product_ids}."
                ) from wse

    def unsubscribe(self, product_ids: List[str], channels: List[str]) -> None:
        """
        **Unsubscribe**
        _______________

        ------------------------

        Unsubscribe to a list of channels for a list of product ids.

        - **product_ids** - product ids to unsubscribe from
        - **channels** - channels to unsubscribe from
        """
        if self.loop and not self.loop.is_closed():
            self._run_coroutine_threadsafe(
                self.unsubscribe_async(product_ids, channels)
            )
        else:
            raise WSClientException("Websocket Client is not open.")

    async def unsubscribe_async(
        self, product_ids: List[str], channels: List[str]
    ) -> None:
        """
        **Unsubscribe Async**
        _____________________

        ------------------------

        Async unsubscribe to a list of channels for a list of product ids.

        - **product_ids** - product ids to unsubscribe from
        - **channels** - channels to unsubscribe from
        """
        self._ensure_websocket_open()
        for channel in channels:
            try:
                if not self.is_authenticated and channel in WS_AUTH_CHANNELS:
                    raise AuthenticationError(
                        "Unauthenticated request to private channel. If you wish to access "
                        "private channels, you must provide your API key and secret when "
                        "initializing the WSClient."
                    )

                is_public = channel not in WS_AUTH_CHANNELS
                message = self._build_subscription_message(
                    product_ids, channel, UNSUBSCRIBE_MESSAGE_TYPE, is_public
                )
                json_message = json.dumps(message)

                logger.debug(
                    "Unsubscribing from channel %s for product IDs: %s",
                    channel,
                    product_ids,
                )

                await self.websocket.send(json_message)

                logger.debug("Successfully sent unsubscribe message.")

                # remove from subscriptions map
                if channel in self.subscriptions:
                    self.subscriptions[channel].difference_update(product_ids)
            except (websockets.exceptions.WebSocketException, OSError) as wse:
                logger.error(
                    "Failed to unsubscribe to %s channel for product IDs %s: %s",
                    channel,
                    product_ids,
                    wse,
                )
                raise WSClientException(
                    f"Failed to unsubscribe to {channel} channel for product ids {product_ids}."
                ) from wse

    def unsubscribe_all(self) -> None:
        """
        **Unsubscribe All**
        ________________________

        ------------------------

        Unsubscribe from all channels you are currently subscribed to.
        """
        if self.loop and not self.loop.is_closed():
            self._run_coroutine_threadsafe(self.unsubscribe_all_async())
        else:
            raise WSClientException("Websocket Client is not open.")

    async def unsubscribe_all_async(self) -> None:
        """
        **Unsubscribe All Async**
        _________________________

        ------------------------

        Async unsubscribe from all channels you are currently subscribed to.
        """
        # Iterate a snapshot: other coroutines on this loop may add channels
        # to the map while we await the unsubscribe sends.
        for channel, product_ids in list(self.subscriptions.items()):
            if product_ids:
                await self.unsubscribe_async(list(product_ids), [channel])

    def sleep_with_exception_check(self, sleep: int) -> None:
        """
        **Sleep with Exception Check**
        ______________________________

        ------------------------

        Sleep for a specified number of seconds and check for background exceptions.

        - **sleep** - number of seconds to sleep.
        """
        time.sleep(sleep)
        self.raise_background_exception()

    async def sleep_with_exception_check_async(self, sleep: int) -> None:
        """
        **Sleep with Exception Check Async**
        ____________________________________

        ------------------------

        Async sleep for a specified number of seconds and check for background exceptions.

        - **sleep** - number of seconds to sleep.
        """
        await asyncio.sleep(sleep)
        self.raise_background_exception()

    def run_forever_with_exception_check(self) -> None:
        """
        **Run Forever with Exception Check**
        ____________________________________

        ------------------------

        Runs an endless loop, checking for background exceptions every second.
        """
        while True:
            time.sleep(1)
            self.raise_background_exception()

    async def run_forever_with_exception_check_async(self) -> None:
        """
        **Run Forever with Exception Check Async**
        __________________________________________

        ------------------------

        Async runs an endless loop, checking for background exceptions every second.
        """
        while True:
            await asyncio.sleep(1)
            self.raise_background_exception()

    def raise_background_exception(self) -> None:
        """
        **Raise Background Exception**
        ______________________________

        ------------------------

        Raise any background exceptions that occurred in the message handler.

        The stored exception is cleared before being raised, so it is raised
        at most once.
        """
        if self._background_exception:
            exception_to_raise = self._background_exception
            self._background_exception = None
            raise exception_to_raise

    def _run_coroutine_threadsafe(self, coro):
        """
        :meta private:

        Schedule ``coro`` on the private event loop and block the calling
        thread until it finishes, returning its result or re-raising its
        exception. Must not be called from the event-loop thread itself.
        """
        future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        return future.result()

    def _is_websocket_open(self):
        """
        :meta private:

        Truthy when a websocket exists and reports itself open.
        """
        return self.websocket and self.websocket.open

    async def _resubscribe(self):
        """
        :meta private:

        Re-send subscriptions for every channel recorded in ``subscriptions``
        (used after a reconnect).
        """
        # Iterate a snapshot so concurrent changes to the map cannot break
        # iteration across the awaits below.
        for channel, product_ids in list(self.subscriptions.items()):
            if product_ids:
                await self.subscribe_async(list(product_ids), [channel])

    async def _retry_connection(self):
        """
        :meta private:

        Reconnect (if needed) and resubscribe, retrying with exponential
        backoff on ``WSClientException``. The last ``WSClientException`` is
        raised once ``_retry_max_tries`` attempts are exhausted.
        """
        self._retry_count = 0

        @backoff.on_exception(
            backoff.expo,
            WSClientException,
            max_tries=self._retry_max_tries,
            base=self._retry_base,
            factor=self._retry_factor,
        )
        async def _retry_connect_and_resubscribe():
            self._retry_count += 1

            logger.debug("Retrying connection attempt %s", self._retry_count)
            if not self._is_websocket_open():
                await self.open_async()

            logger.debug("Resubscribing to channels")
            self._retry_count = 0
            await self._resubscribe()

        return await _retry_connect_and_resubscribe()

    async def _message_handler(self):
        """
        :meta private:

        Receive messages while the socket is open and pass each to
        ``on_message``. On ``ConnectionClosedError`` it reconnects when
        ``retry`` is enabled; if retrying is disabled or fails, a
        ``WSClientConnectionClosedException`` is stored for
        ``raise_background_exception``. Other websocket, JSON-decoding or
        client errors are stored as ``WSClientException``. A clean close
        (``ConnectionClosedOK``) just ends the loop.
        """
        self.handler_open = True
        while self._is_websocket_open():
            try:
                message = await self.websocket.recv()
                if self.on_message:
                    self.on_message(message)
            except websockets.exceptions.ConnectionClosedOK as cco:
                logger.debug("Connection closed (OK): %s", cco)
                break
            except websockets.exceptions.ConnectionClosedError as cce:
                logger.error("Connection closed (ERROR): %s", cce)
                if self.retry:
                    self._retrying = True
                    try:
                        logger.debug("Retrying connection")
                        await self._retry_connection()
                        self._retrying = False
                    except WSClientException:
                        logger.error(
                            "Connection closed unexpectedly. Retry attempts failed."
                        )
                        self._background_exception = WSClientConnectionClosedException(
                            "Connection closed unexpectedly. Retry attempts failed."
                        )
                        self.subscriptions = {}
                        self._retrying = False
                        self._retry_count = 0
                        break
                else:
                    logger.error("Connection closed unexpectedly with error: %s", cce)
                    self._background_exception = WSClientConnectionClosedException(
                        f"Connection closed unexpectedly with error: {cce}"
                    )
                    self.subscriptions = {}
                    break
            except (
                websockets.exceptions.WebSocketException,
                json.JSONDecodeError,
                WSClientException,
            ) as e:
                logger.error("Exception in message handler: %s", e)
                self._background_exception = WSClientException(
                    f"Exception in message handler: {e}"
                )
                break

    def _build_subscription_message(
        self, product_ids: List[str], channel: str, message_type: str, public: bool
    ):
        """
        :meta private:

        Build the (un)subscribe payload. A ``jwt`` entry is added whenever the
        client is authenticated. NOTE(review): ``public`` is currently unused.
        """
        return {
            "type": message_type,
            "product_ids": product_ids,
            "channel": channel,
            **(
                {
                    "jwt": jwt_generator.build_ws_jwt(self.api_key, self.api_secret),
                }
                if self.is_authenticated
                else {}
            ),
        }

    def _ensure_websocket_not_open(self):
        """
        :meta private:

        Raise ``WSClientException`` if the websocket is already open.
        """
        if self._is_websocket_open():
            raise WSClientException("WebSocket is already open.")

    def _ensure_websocket_open(self):
        """
        :meta private:

        Raise ``WSClientException`` if the websocket is closed or unset.
        """
        if not self._is_websocket_open():
            raise WSClientException("WebSocket is closed or was never opened.")

    def _set_headers(self):
        """
        :meta private:

        Extra handshake headers: the current retry attempt number is sent as
        ``x-cb-retry-counter`` during reconnects; otherwise no extra headers.
        """
        if self._retry_count > 0:
            return {"x-cb-retry-counter": str(self._retry_count)}
        else:
            return {}