Source code for znsocket.client

import dataclasses
import datetime
import functools
import gzip
import json
import logging
import typing as t
import uuid

import socketio.exceptions
import typing_extensions as tyex
from redis import Redis

from znsocket import exceptions
from znsocket.abc import RefreshDataTypeDict
from znsocket.utils import handle_error, parse_url

log = logging.getLogger(__name__)


def _handle_data(data: dict):
    if "type" in data:
        if data["type"] == "set":
            return set(data["data"])
        else:
            raise TypeError(f"Can not convert type '{data['type']}'")
    return data["data"]


@dataclasses.dataclass
class Client:
    """Client to interact with a znsocket server.

    The Client class provides an interface to connect to and communicate
    with a znsocket server using websockets. It supports Redis-like commands,
    automatic message chunking for large data, and compression, and provides
    automatic reconnection capabilities.

    Large messages are automatically handled through compression and chunking:

    - Messages are compressed using gzip when they exceed the compression threshold
    - If compressed messages still exceed size limits, they are split into chunks
    - All chunking and compression operations are transparent to the user

    Parameters
    ----------
    address : str
        The address of the znsocket server (e.g., "http://localhost:5000").
    decode_responses : bool, optional
        Whether to decode responses from the server. Default is True.
    namespace : str, optional
        The namespace to connect to. Default is "/znsocket".
    refresh_callbacks : dict, optional
        A dictionary of callbacks to call when a refresh event is received.
    delay_between_calls : datetime.timedelta, optional
        The time to wait between calls. Default is None.
    retry : int, optional
        The number of times to retry a failed call. Default is 1.
    connect_wait_timeout : int, optional
        Timeout in seconds for connection establishment. Default is 1.
    max_message_size_bytes : int, optional
        Maximum size in bytes for a single message before chunking. If None,
        uses the server limit or defaults to 1 MB. Messages exceeding this
        size are automatically chunked.
    enable_compression : bool, optional
        Whether to enable gzip compression for large messages. Default is True.
    compression_threshold : int, optional
        Minimum message size in bytes to trigger compression.
        Default is 1024 (1 KB).

    Attributes
    ----------
    sio : socketio.Client
        The underlying socket.io client instance.

    Examples
    --------
    Basic usage:

    >>> client = Client("http://localhost:5000")
    >>> client.hset("mykey", "field1", "value1")
    >>> client.hget("mykey", "field1")
    'value1'

    Large data with automatic chunking:

    >>> import numpy as np
    >>> client = Client("http://localhost:5000")
    >>> large_data = np.random.rand(1000, 1000)  # ~8MB array
    >>> client.hset("data", "array", large_data)  # Automatically compressed/chunked
    >>> retrieved = client.hget("data", "array")

    Custom chunking configuration:

    >>> client = Client(
    ...     "http://localhost:5000",
    ...     max_message_size_bytes=500000,  # 500KB limit
    ...     enable_compression=True,
    ...     compression_threshold=2048,  # Compress messages > 2KB
    ... )
    """

    address: str
    decode_responses: bool = True
    sio: socketio.Client = dataclasses.field(
        default_factory=socketio.Client, repr=False, init=False
    )
    namespace: str = "/znsocket"
    refresh_callbacks: dict[str, t.Callable[[t.Any], None]] = dataclasses.field(
        default_factory=dict
    )
    _adapter_callbacks: dict[str, t.Callable[[tuple[list, dict]], t.Any]] = (
        dataclasses.field(default_factory=dict, init=False)
    )
    delay_between_calls: datetime.timedelta | None = None
    retry: int = 1
    connect_wait_timeout: int = 1
    max_message_size_bytes: int | None = None
    enable_compression: bool = True
    compression_threshold: int = 1024
    _last_call: datetime.datetime = dataclasses.field(
        default_factory=datetime.datetime.now, init=False
    )

    def register_adapter_callback(
        self, key: str, callback: t.Callable[[tuple[list, dict]], t.Any]
    ) -> None:
        """Register an adapter callback for a specific key.

        Parameters
        ----------
        key : str
            The adapter key to register the callback for.
        callback : Callable[[tuple[list, dict]], Any]
            The callback function to handle adapter requests for this key.
        """
        self._adapter_callbacks[key] = callback
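
    # A hedged usage sketch: the "adapter:get" handler registered in
    # _setup_event_handlers looks up callbacks by data[1]["key"], so the
    # callback receives the raw (args, kwargs) payload. The key name
    # "frames" below is illustrative, not part of the API.
    #
    # >>> client = Client("http://localhost:5000")
    # >>> def frames_adapter(data: tuple[list, dict]):
    # ...     args, kwargs = data
    # ...     return {"key": kwargs["key"]}
    # >>> client.register_adapter_callback("frames", frames_adapter)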

    def pipeline(self, *args, **kwargs) -> "Pipeline":
        """Create a pipeline for batching Redis commands.

        Parameters
        ----------
        *args
            Positional arguments to pass to the Pipeline constructor.
        **kwargs
            Keyword arguments to pass to the Pipeline constructor.

        Returns
        -------
        Pipeline
            A new Pipeline instance for batching commands.

        Examples
        --------
        >>> client = Client("http://localhost:5000")
        >>> pipe = client.pipeline()
        >>> pipe.hset("key1", "field1", "value1")
        >>> pipe.hset("key2", "field2", "value2")
        >>> results = pipe.execute()
        """
        return Pipeline(self, *args, **kwargs)

    def _serialize_message(self, args: tuple, kwargs: dict) -> bytes:
        """Serialize message arguments to bytes with optional compression.

        Parameters
        ----------
        args : tuple
            Positional arguments.
        kwargs : dict
            Keyword arguments.

        Returns
        -------
        bytes
            Serialized (and optionally compressed) message as bytes.
        """
        payload = [args, kwargs]
        json_bytes = json.dumps(payload, separators=(",", ":"), default=str).encode(
            "utf-8"
        )

        # Apply compression if enabled and the message is large enough
        if self.enable_compression and len(json_bytes) > self.compression_threshold:
            compressed_bytes = gzip.compress(json_bytes, compresslevel=6)
            # Only use compression if it actually reduces the size
            if len(compressed_bytes) < len(json_bytes):
                log.debug(
                    f"Compressed message from {len(json_bytes)} to {len(compressed_bytes)} bytes "
                    f"({len(compressed_bytes) / len(json_bytes) * 100:.1f}%)"
                )
                # Prepend compression flag (1 byte) + original size (4 bytes)
                return b"\x01" + len(json_bytes).to_bytes(4, "big") + compressed_bytes

        # Return uncompressed with flag
        return b"\x00" + json_bytes

    def _split_message_bytes(
        self, message_bytes: bytes, max_chunk_size: int
    ) -> list[bytes]:
        """Split message bytes into chunks.

        Parameters
        ----------
        message_bytes : bytes
            The message to split.
        max_chunk_size : int
            Maximum size per chunk in bytes.

        Returns
        -------
        list[bytes]
            List of byte chunks.
        """
        chunks = []
        for i in range(0, len(message_bytes), max_chunk_size):
            chunks.append(message_bytes[i : i + max_chunk_size])
        return chunks

    def _deserialize_message(self, message_bytes: bytes) -> tuple[tuple, dict]:
        """Deserialize message bytes back to args and kwargs with compression support.

        Parameters
        ----------
        message_bytes : bytes
            The serialized (and optionally compressed) message.

        Returns
        -------
        tuple[tuple, dict]
            The original args and kwargs.
        """
        if len(message_bytes) == 0:
            raise ValueError("Empty message bytes")

        # Check compression flag
        compression_flag = message_bytes[0:1]

        if compression_flag == b"\x01":
            # Compressed message: flag (1) + original_size (4) + compressed_data
            if len(message_bytes) < 5:
                raise ValueError("Invalid compressed message format")
            original_size = int.from_bytes(message_bytes[1:5], "big")
            compressed_data = message_bytes[5:]
            json_bytes = gzip.decompress(compressed_data)
            if len(json_bytes) != original_size:
                raise ValueError("Decompressed size mismatch")
            log.debug(
                f"Decompressed message from {len(compressed_data)} to {len(json_bytes)} bytes"
            )
        elif compression_flag == b"\x00":
            # Uncompressed message: flag (1) + json_data
            json_bytes = message_bytes[1:]
        else:
            # Legacy format (no compression flag) - assume uncompressed
            json_bytes = message_bytes

        payload = json.loads(json_bytes.decode("utf-8"))
        return tuple(payload[0]), payload[1]
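
    # A standalone sketch of the wire framing produced above, using only the
    # stdlib: flag byte b"\x00" precedes plain JSON, while b"\x01" plus a
    # 4-byte big-endian original size precedes the gzip payload. The tiny
    # payload here is framed as compressed purely to show the layout; the
    # real method only compresses beyond compression_threshold.
    #
    # >>> import gzip, json
    # >>> body = json.dumps([["hget", "key"], {}], separators=(",", ":")).encode()
    # >>> framed = b"\x01" + len(body).to_bytes(4, "big") + gzip.compress(body)
    # >>> int.from_bytes(framed[1:5], "big") == len(gzip.decompress(framed[5:]))
    # True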

    @classmethod
    def from_url(cls, url, namespace: str = "/znsocket", **kwargs) -> "Client":
        """Connect to a znsocket server using a URL.

        Parameters
        ----------
        url : str
            The URL of the znsocket server. Should be in the format
            "znsocket://127.0.0.1:5000".
        namespace : str, optional
            The namespace to connect to. Default is "/znsocket".
        **kwargs
            Additional keyword arguments to pass to the Client constructor.

        Returns
        -------
        Client
            A new Client instance connected to the specified server.

        Examples
        --------
        >>> client = Client.from_url("znsocket://127.0.0.1:5000")
        >>> client.hset("key", "field", "value")
        """
        return cls(
            address=url.replace("znsocket://", "http://"),
            namespace=namespace,
            **kwargs,
        )

    def __post_init__(self):
        self._setup_event_handlers()
        self._establish_connection()
        self._configure_message_size()

    def _setup_event_handlers(self):
        @self.sio.on("refresh", namespace=self.namespace)
        def refresh(data: RefreshDataTypeDict):
            for key in self.refresh_callbacks:
                if data["target"] == key:
                    self.refresh_callbacks[key](data["data"])

        @self.sio.on("adapter:get", namespace=self.namespace)
        def adapter(data: tuple[list, dict]):
            key = data[1]["key"]
            if key in self._adapter_callbacks:
                return self._adapter_callbacks[key](data)
            else:
                raise exceptions.ZnSocketError("No adapter callback set")

    def _establish_connection(self):
        _url, _path = parse_url(self.address)
        try:
            self.sio.connect(
                _url,
                namespaces=[self.namespace],
                wait=True,
                wait_timeout=self.connect_wait_timeout,
                socketio_path=f"{_path}/socket.io" if _path else "socket.io",
            )
        except socketio.exceptions.ConnectionError as err:
            raise exceptions.ConnectionError(self.address) from err

        if not self.decode_responses:
            raise NotImplementedError("decode_responses=False is not supported yet")

    def _configure_message_size(self):
        server_config = self.sio.call("server_config", namespace=self.namespace)
        if server_config is None:
            raise exceptions.ZnSocketError("No response from server")
        # Default to 1 MB if the server does not report a buffer size
        server_buffer_size = int(server_config.get("max_http_buffer_size", 1000000))
        if self.max_message_size_bytes is None:
            self.max_message_size_bytes = server_buffer_size
        if server_buffer_size < self.max_message_size_bytes:
            log.warning(
                f"Server max_http_buffer_size ({server_buffer_size:,} bytes) "
                f"is smaller than client max_message_size_bytes ({self.max_message_size_bytes:,} bytes). "
                "This may cause issues with large messages."
            )

    def call(self, event: str, *args, **kwargs) -> t.Any:
        """Call an event on the server.

        Automatically handles chunking for large messages that exceed the size limit.

        Parameters
        ----------
        event : str
            The event name to call on the server.
        *args
            Positional arguments to pass to the event.
        **kwargs
            Keyword arguments to pass to the event.

        Returns
        -------
        Any
            The response from the server.

        Raises
        ------
        socketio.exceptions.TimeoutError
            If the server does not respond within the timeout period
            after all retries.
        """
        if self.delay_between_calls:
            time_since_last_call = datetime.datetime.now() - self._last_call
            delay_needed = self.delay_between_calls - time_since_last_call
            if delay_needed > datetime.timedelta(0):
                self.sio.sleep(delay_needed.total_seconds())
            self._last_call = datetime.datetime.now()

        # Check if the message needs chunking or compression
        message_bytes = self._serialize_message(args, kwargs)
        is_compressed = message_bytes.startswith(b"\x01")

        if len(message_bytes) > self.max_message_size_bytes:
            # Use chunked transmission
            log.debug(
                f"Message size ({len(message_bytes):,} bytes) exceeds limit "
                f"({self.max_message_size_bytes:,} bytes). Using chunked transmission."
            )
            return self._call_chunked(event, message_bytes)
        elif is_compressed:
            # Message was compressed but fits in a single transmission
            return self._call_normal_bytes(event, message_bytes)
        else:
            # Small uncompressed message - use the original path
            return self._call_normal(event, args, kwargs)
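
    # A worked example of the rate limiting above: with
    # delay_between_calls=timedelta(seconds=1), a call issued 0.4 s after the
    # previous one sleeps for the remaining 0.6 s before being sent.
    #
    # >>> import datetime
    # >>> delay = datetime.timedelta(seconds=1) - datetime.timedelta(seconds=0.4)
    # >>> delay.total_seconds()
    # 0.6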

    def _call_normal(self, event: str, args: tuple, kwargs: dict) -> t.Any:
        """Send a normal (non-chunked) message to the server."""
        for idx in reversed(range(self.retry + 1)):
            try:
                return self.sio.call(event, [args, kwargs], namespace=self.namespace)
            except socketio.exceptions.TimeoutError:
                if idx == 0:
                    raise
                log.warning(f"Connection error. Retrying... {idx} attempts left")
                self.sio.sleep(1)

    def _call_normal_bytes(self, event: str, message_bytes: bytes) -> t.Any:
        """Send a normal (non-chunked) message using pre-serialized bytes."""
        for idx in reversed(range(self.retry + 1)):
            try:
                # Send the compressed message as a single chunk-like message
                return self.sio.call(
                    "compressed_message",
                    {"event": event, "data": message_bytes},
                    namespace=self.namespace,
                )
            except socketio.exceptions.TimeoutError:
                if idx == 0:
                    raise
                log.warning(f"Connection error. Retrying... {idx} attempts left")
                self.sio.sleep(1)

    def _call_chunked(self, event: str, message_bytes: bytes) -> t.Any:
        """Send a chunked message to the server with optimized encoding."""
        # Reserve space for chunk metadata (should not be more than 250 bytes)
        chunk_size = self.max_message_size_bytes - 250
        if chunk_size <= 0:
            raise ValueError(
                "max_message_size_bytes is too small to accommodate chunk metadata"
            )

        chunks = self._split_message_bytes(message_bytes, chunk_size)
        chunk_id = str(uuid.uuid4())
        log.debug(
            f"Splitting message into {len(chunks)} chunks with ID {chunk_id} "
            f"(original: {len(message_bytes)} bytes)"
        )
        self._send_all_chunks(chunks, chunk_id, event)
        return self._get_chunked_result(chunk_id)

    def _send_all_chunks(self, chunks: list, chunk_id: str, event: str) -> None:
        """Send all chunks with retry logic."""
        for chunk_index, chunk_data in enumerate(chunks):
            chunk_metadata = {
                "chunk_id": chunk_id,
                "chunk_index": chunk_index,
                "total_chunks": len(chunks),
                "event": event,
                "data": chunk_data,
                "size": len(chunk_data),
            }
            self._send_single_chunk(chunk_metadata, chunk_index, len(chunks))

    def _send_single_chunk(
        self, chunk_metadata: dict, chunk_index: int, total_chunks: int
    ) -> None:
        """Send a single chunk with retry logic."""
        for idx in reversed(range(self.retry + 1)):
            try:
                log.debug(
                    f"Sending chunk {chunk_index + 1}/{total_chunks} "
                    f"({len(chunk_metadata['data']):,} bytes)"
                )
                response = self.sio.call(
                    "chunked_message", chunk_metadata, namespace=self.namespace
                )
                if response and response.get("error"):
                    raise exceptions.ZnSocketError(
                        f"Chunked message failed: {response['error']}"
                    )
                log.debug(f"Chunk {chunk_index + 1}/{total_chunks} sent successfully")
                break
            except socketio.exceptions.TimeoutError:
                if idx == 0:
                    raise
                log.warning(
                    f"Connection error on chunk {chunk_index}. Retrying... {idx} attempts left"
                )
                self.sio.sleep(1)

    def _get_chunked_result(self, chunk_id: str) -> t.Any:
        """Wait for the final response with the assembled result."""
        for idx in reversed(range(self.retry + 1)):
            try:
                final_response = self.sio.call(
                    "get_chunked_result",
                    {"chunk_id": chunk_id},
                    namespace=self.namespace,
                )
                if final_response and final_response.get("error"):
                    raise exceptions.ZnSocketError(
                        f"Chunked message failed: {final_response['error']}"
                    )
                log.debug(
                    f"Received final response for chunk ID {chunk_id}: {final_response}"
                )
                return final_response
            except socketio.exceptions.TimeoutError:
                if idx == 0:
                    raise
                log.warning(
                    f"Connection error getting chunked result. Retrying... {idx} attempts left"
                )
                self.sio.sleep(1)

    def _redis_command(self, command, *args, **kwargs):
        """Generic handler for Redis commands."""
        result = self.call(command, *args, **kwargs)
        if result is None:
            raise exceptions.ZnSocketError("No response from server")
        handle_error(result)
        return _handle_data(result)

    def __getattr__(self, name):
        """Intercepts method calls to dynamically route Redis commands."""
        # Check if the name corresponds to a Redis command
        if hasattr(Redis, name):
            return functools.partial(self._redis_command, name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    @tyex.deprecated("hmset() is deprecated. Use hset() instead.")
    def hmset(self, name, mapping):
        return self.hset(name, mapping=mapping)


@dataclasses.dataclass
class Pipeline:
    """A pipeline of Redis commands to be executed as a batch on the server.

    The Pipeline class allows batching multiple Redis commands together
    to reduce network overhead and improve performance when executing
    multiple operations. Large pipelines are automatically chunked based
    on message size.

    Parameters
    ----------
    client : Client
        The client to send the pipeline to.

    Attributes
    ----------
    pipeline : list
        Internal list storing the commands to be executed.

    Examples
    --------
    >>> client = Client("http://localhost:5000")
    >>> pipe = client.pipeline()
    >>> pipe.hset("key1", "field1", "value1")
    >>> pipe.hset("key2", "field2", "value2")
    >>> results = pipe.execute()
    """

    client: Client
    pipeline: list = dataclasses.field(default_factory=list, init=False)

    def _add_to_pipeline(self, command, *args, **kwargs):
        """Generic handler to add Redis commands to the pipeline."""
        self.pipeline.append((command, [args, kwargs]))
        return self

    def __getattr__(self, name):
        """Intercepts method calls to dynamically add Redis commands to the pipeline."""
        # Check if the name corresponds to a Redis command
        if hasattr(Redis, name):
            return functools.partial(self._add_to_pipeline, name)
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def _send_message(self, message) -> list:
        """Send a message to the server and process the response."""
        result = self.client.call(
            "pipeline",
            message=message,
        )
        if result is None:
            raise exceptions.ZnSocketError("No response from server")
        handle_error(result)
        return [_handle_data(res) for res in result["data"]]

    def execute(self):
        """Execute the pipeline of commands as a batch on the server.

        Sends all queued commands to the server and returns the results
        in order. Automatic chunking is handled by the client.call() method
        based on message size.

        Returns
        -------
        list
            A list of results corresponding to each command in the pipeline,
            in order.

        Raises
        ------
        exceptions.ZnSocketError
            If the server returns no response or an error occurs.

        Examples
        --------
        >>> pipe = client.pipeline()
        >>> pipe.hset("key1", "field1", "value1")
        >>> pipe.hget("key1", "field1")
        >>> results = pipe.execute()
        >>> print(results)  # [1, 'value1']
        """
        return self._send_message(self.pipeline)
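
# A hedged sketch of the queued-command shape Pipeline accumulates before
# execute() ships it as the "pipeline" event. No server is contacted here;
# client=None is a placeholder and a real Client is required to execute.
#
# >>> pipe = Pipeline(client=None)
# >>> pipe = pipe.hset("key1", "field1", "value1")
# >>> pipe.pipeline
# [('hset', [('key1', 'field1', 'value1'), {}])]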