Source code for znsocket.server

import dataclasses
import gzip
import json
import logging
import threading
import time
import typing as t
from copy import deepcopy

import eventlet.wsgi
import socketio

from znsocket.abc import RefreshDataTypeDict
from znsocket.exceptions import DataError, ResponseError

log = logging.getLogger(__name__)


@dataclasses.dataclass
class Storage:
    """In-memory storage backend for znsocket server.

    The Storage class provides Redis-compatible data storage operations
    including hash tables, lists, sets, and basic key-value operations.
    All data is stored in memory using Python data structures.

    Parameters
    ----------
    content : dict, optional
        Initial content for the storage. Default is an empty dictionary.

    Attributes
    ----------
    content : dict
        The internal storage dictionary containing all data.

    Examples
    --------
    >>> storage = Storage()
    >>> storage.hset("users", "user1", "John")
    1
    >>> storage.hget("users", "user1")
    'John'
    """

    content: dict = dataclasses.field(default_factory=dict)
    def hset(
        self,
        name: str,
        key: t.Optional[str] = None,
        value: t.Optional[str] = None,
        mapping: t.Optional[dict] = None,
        items: t.Optional[list] = None,
    ):
        """Set field(s) in a hash.

        Parameters
        ----------
        name : str
            The name of the hash.
        key : str, optional
            The field name to set.
        value : str, optional
            The value to set for the field.
        mapping : dict, optional
            A dictionary of field-value pairs to set.
        items : list, optional
            A list of alternating field-value pairs to set.

        Returns
        -------
        int
            The number of field-value pairs that were set.

        Raises
        ------
        DataError
            If no key-value pairs are provided, or if value is None when
            key is provided.
        """
        if key is None and not mapping and not items:
            raise DataError("'hset' with no key value pairs")
        if value is None and not mapping and not items:
            raise DataError(f"Invalid input of type {type(value)}")
        pieces = []
        if items:
            pieces.extend(items)
        if key is not None:
            pieces.extend((key, value))
        if mapping:
            for pair in mapping.items():
                pieces.extend(pair)

        if name not in self.content:
            self.content[name] = {}
        for i in range(0, len(pieces), 2):
            self.content[name][pieces[i]] = pieces[i + 1]
        return len(pieces) // 2
    def hget(self, name, key):
        """Get the value of a hash field.

        Parameters
        ----------
        name : str
            The name of the hash.
        key : str
            The field name to get.

        Returns
        -------
        str or None
            The value of the field, or None if the field does not exist.
        """
        try:
            return self.content[name][key]
        except KeyError:
            return None
    def hmget(self, name, keys):
        response = []
        for key in keys:
            try:
                response.append(self.content[name][key])
            except KeyError:
                response.append(None)
        return response

    def hkeys(self, name):
        try:
            return list(self.content[name].keys())
        except KeyError:
            return []

    def delete(self, name):
        try:
            del self.content[name]
            return 1
        except KeyError:
            return 0

    def exists(self, name):
        return 1 if name in self.content else 0

    def llen(self, name):
        try:
            return len(self.content[name])
        except KeyError:
            return 0

    def rpush(self, name, value):
        try:
            self.content[name].append(value)
        except KeyError:
            self.content[name] = [value]
        return len(self.content[name])

    def lpush(self, name, value):
        try:
            self.content[name].insert(0, value)
        except KeyError:
            self.content[name] = [value]
        return len(self.content[name])

    def lindex(self, name, index):
        if index is None:
            raise DataError("Invalid input of type None")
        try:
            return self.content[name][index]
        except (KeyError, IndexError):
            return None
        except TypeError:  # index is not an integer
            return None

    def set(self, name, value):
        if value is None or name is None:
            raise DataError("Invalid input of type None")
        self.content[name] = value
        return True

    def get(self, name, default=None):
        return self.content.get(name, default)

    def smembers(self, name):
        try:
            response = self.content[name]
        except KeyError:
            response = set()
        if not isinstance(response, set):
            raise ResponseError(
                "WRONGTYPE Operation against a key holding the wrong kind of value"
            )
        return response

    def lrange(self, name, start, end):
        # Redis uses inclusive end indices: -1 means "through the last
        # element"; any other end is shifted by one for Python's exclusive
        # slicing (this also handles negative ends such as -2 correctly).
        if end == -1:
            end = None
        else:
            end += 1
        try:
            return self.content[name][start:end]
        except KeyError:
            return []

    def lset(self, name, index, value):
        try:
            self.content[name][index] = value
        except KeyError:
            raise ResponseError("no such key")
        except IndexError:
            raise ResponseError("index out of range")

    def lrem(self, name, count, value):
        # Note: negative counts (remove from the tail) are not supported here.
        if count is None or value is None or name is None:
            raise DataError("Invalid input of type None")
        if count == 0:
            # Remove all occurrences and return how many were removed.
            try:
                before = len(self.content[name])
                self.content[name] = [x for x in self.content[name] if x != value]
                return before - len(self.content[name])
            except KeyError:
                return 0
        removed = 0
        while removed < count:
            try:
                self.content[name].remove(value)
                removed += 1
            except KeyError:
                return 0
            except ValueError:  # no more occurrences left
                break
        return removed

    def sadd(self, name, value):
        try:
            self.content[name].add(value)
        except KeyError:
            self.content[name] = {value}

    def flushall(self):
        self.content.clear()

    def srem(self, name, value):
        try:
            self.content[name].remove(value)
            return 1
        except KeyError:
            return 0

    def linsert(self, name, where, pivot, value):
        try:
            index = self.content[name].index(pivot)
            if where == "BEFORE":
                self.content[name].insert(index, value)
            elif where == "AFTER":
                self.content[name].insert(index + 1, value)
            # Redis returns the new list length on success.
            return len(self.content[name])
        except KeyError:
            return 0
        except ValueError:
            return -1

    def hexists(self, name, key):
        try:
            return 1 if key in self.content[name] else 0
        except KeyError:
            return 0

    def hdel(self, name, key):
        try:
            del self.content[name][key]
            return 1
        except KeyError:
            return 0

    def hlen(self, name):
        try:
            return len(self.content[name])
        except KeyError:
            return 0

    def hvals(self, name):
        try:
            return list(self.content[name].values())
        except KeyError:
            return []

    def lpop(self, name):
        try:
            return self.content[name].pop(0)
        except (KeyError, IndexError):
            return None

    def scard(self, name):
        try:
            return len(self.content[name])
        except KeyError:
            return 0

    def hgetall(self, name):
        try:
            return self.content[name]
        except KeyError:
            return {}

    def copy(self, src, dst):
        if src == dst:
            return False
        if src not in self.content:
            return False
        if dst in self.content:
            return False
        self.content[dst] = deepcopy(self.content[src])
        return True
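
A short usage sketch for the helpers above (return values follow the implementation; the key names are illustrative only):

storage = Storage()

# hset accepts a single pair, a mapping, or a flat list of alternating items
storage.hset("config", "theme", "dark")                      # -> 1
storage.hset("config", mapping={"lang": "en", "tz": "UTC"})  # -> 2
storage.hset("config", items=["debug", "true"])              # -> 1

# List operations mirror RPUSH/LPUSH/LRANGE/LREM
storage.rpush("queue", "a")      # -> 1 (length after push)
storage.rpush("queue", "b")      # -> 2
storage.lpush("queue", "z")      # -> 3; list is now ["z", "a", "b"]
storage.lrange("queue", 0, -1)   # -> ["z", "a", "b"]
storage.lrem("queue", 0, "z")    # -> 1 (all occurrences removed)

# Sets and key copying
storage.sadd("tags", "x")
storage.smembers("tags")         # -> {"x"}
storage.copy("queue", "backup")  # -> True (deep copy under the new key)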
@dataclasses.dataclass
class Server:
    """znsocket server implementation.

    The Server class provides a websocket-based server that implements
    Redis-compatible operations with automatic support for large message
    handling through chunking and compression. It uses eventlet for async
    operations and socket.io for websocket communication.

    Large Message Handling:
    - Automatically receives and reassembles chunked messages from clients
    - Supports compressed message decompression using gzip
    - Handles both single compressed messages and multi-chunk transmissions
    - Provides automatic cleanup of expired chunk storage

    The server automatically handles three types of message transmission:

    1. Normal messages: Small messages sent directly
    2. Compressed messages: Large messages compressed and sent as single units
    3. Chunked messages: Very large messages split into multiple chunks

    Parameters
    ----------
    port : int, optional
        The port number to bind the server to. Default is 5000.
    max_http_buffer_size : int, optional
        Maximum size of the HTTP buffer in bytes. This determines the largest
        single message the server can receive. Default is None (uses the
        socket.io default of ~1MB). Messages larger than this limit must be
        chunked by clients.
    async_mode : str, optional
        Async mode to use ('eventlet', 'gevent', etc.). Default is None.
    logger : bool, optional
        Whether to enable logging. Default is False.
    storage : str, optional
        Storage backend to use: 'memory' or a Redis URL such as
        'redis://localhost:6379'. Default is 'memory'.

    Examples
    --------
    Basic server with default settings:

    >>> server = Server(port=5000)
    >>> server.run()  # This will block and run the server

    Server with a larger buffer for handling big messages:

    >>> server = Server(
    ...     port=5000,
    ...     max_http_buffer_size=10 * 1024 * 1024,  # 10MB buffer
    ...     logger=True
    ... )
    >>> server.run()

    Server with a Redis backend:

    >>> server = Server(port=5000, storage='redis://localhost:6379')
    >>> server.run()
    """

    port: int = 5000
    max_http_buffer_size: t.Optional[int] = None
    async_mode: t.Optional[str] = None
    logger: bool = False
    storage: str = "memory"
    @classmethod
    def from_url(cls, url: str, **kwargs) -> "Server":
        """Create a Server instance from a URL.

        Parameters
        ----------
        url : str
            The URL to parse; should be in the format "znsocket://host:port".
        **kwargs
            Additional keyword arguments to pass to the Server constructor.

        Returns
        -------
        Server
            A new Server instance configured with the port from the URL.

        Raises
        ------
        ValueError
            If the URL doesn't start with "znsocket://".
        """
        # server url looks like "znsocket://127.0.0.1:5000"
        if not url.startswith("znsocket://"):
            raise ValueError("Invalid URL")
        port = int(url.split(":")[-1])
        return cls(port=port, **kwargs)
    def run(self) -> None:
        """Run the server (blocking).

        Starts the znsocket server and blocks until the server is stopped.
        The server will listen on the configured port and handle incoming
        websocket connections.

        Notes
        -----
        This method blocks the current thread. To run the server in a
        non-blocking way, consider using threading or asyncio.
        """
        sio = get_sio(
            max_http_buffer_size=self.max_http_buffer_size,
            async_mode=self.async_mode,
            logger=self.logger,
            engineio_logger=self.logger,
        )

        # Resolve storage backend
        if self.storage.startswith("redis://"):
            import redis

            resolved_storage = redis.Redis.from_url(
                self.storage, decode_responses=True
            )
        elif self.storage == "memory":
            # attach_events creates an in-memory Storage when given None
            resolved_storage = None
        else:
            raise ValueError(f"Unsupported storage backend: {self.storage}")

        # Attach events with resolved storage
        attach_events(sio, storage=resolved_storage, namespace="/znsocket")
        server_app = socketio.WSGIApp(sio)
        eventlet.wsgi.server(eventlet.listen(("0.0.0.0", self.port)), server_app)
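
A minimal usage sketch for the two entry points above. Since run() blocks, the daemon-thread pattern below follows the suggestion in the Notes; it assumes your eventlet setup tolerates running the server loop off the main thread:

import threading

server = Server.from_url("znsocket://127.0.0.1:5000", logger=True)
assert server.port == 5000

thread = threading.Thread(target=server.run, daemon=True)
thread.start()
# ... interact with the server, then let the daemon thread exit with the process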
def get_sio(
    max_http_buffer_size: t.Optional[int] = None,
    async_mode: t.Optional[str] = None,
    **kwargs,
) -> socketio.Server:
    if max_http_buffer_size is not None:
        kwargs["max_http_buffer_size"] = max_http_buffer_size
    if async_mode is not None:
        kwargs["async_mode"] = async_mode

    # Enable compression for better performance
    kwargs.setdefault("compression", True)
    kwargs.setdefault("compression_threshold", 1024)  # Compress messages >1KB

    return socketio.Server(**kwargs)
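
For reference, Server.run builds its socket.io server through this helper; the buffer size below is an illustrative value, not a recommended default:

sio = get_sio(
    max_http_buffer_size=10 * 1024 * 1024,  # accept single messages up to 10MB
    async_mode="eventlet",
    logger=True,
    engineio_logger=True,
)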
def attach_events(  # noqa: C901
    sio: socketio.Server, namespace: str = "/znsocket", storage=None
) -> socketio.Server:
    """Attach event handlers to a socket.io server.

    This function sets up all the event handlers needed for the znsocket
    server to respond to client requests. It handles Redis-compatible
    operations, pipeline commands, and adapter functionality.

    Parameters
    ----------
    sio : socketio.Server
        The socket.io server instance to attach events to.
    namespace : str, optional
        The namespace to attach events to. Default is "/znsocket".
    storage : Storage, optional
        The storage backend to use. If None, a new Storage instance
        is created.

    Returns
    -------
    socketio.Server
        The socket.io server instance with events attached.

    Examples
    --------
    >>> sio = socketio.Server()
    >>> attach_events(sio)
    >>> # Now sio can handle znsocket events
    """
    if storage is None:
        storage = Storage()

    adapter = {}
    rooms = set()

    @sio.on("*", namespace=namespace)
    def handle_all_events(event, sid, data):
        """Handle any event dynamically by mapping the event name to a storage method."""
        args, kwargs = data
        if hasattr(storage, event):
            try:
                result = {"data": getattr(storage, event)(*args, **kwargs)}
                if isinstance(result["data"], set):
                    result["data"] = list(result["data"])
                    result["type"] = "set"
                return result
            except TypeError as e:
                return {
                    "error": {
                        "msg": f"Invalid arguments for {event}: {str(e)}",
                        "type": "TypeError",
                    }
                }
            except Exception as e:
                return {"error": {"msg": str(e), "type": type(e).__name__}}
        else:
            return {
                "error": {"msg": f"Unknown event: {event}", "type": "UnknownEventError"}
            }

    @sio.event(namespace=namespace)
    def server_config(sid) -> dict:
        """Get the server configuration."""
        config = {
            "max_http_buffer_size": sio.eio.max_http_buffer_size,
            "async_mode": sio.eio.async_mode,
            "namespace": namespace,
        }
        log.debug(f"Server config: {config}")
        return config

    @sio.event(namespace=namespace)
    def check_adapter(sid, data: tuple[list, dict]) -> bool:
        """Check if the adapter is available."""
        key = data[1]["key"]
        rooms.add(key)
        return key in adapter

    @sio.event(namespace=namespace)
    def adapter_exists(sid, data: tuple[list, dict]) -> bool:
        """Check if the adapter exists."""
        key = data[1]["key"]
        return key in adapter

    @sio.event(namespace=namespace)
    def register_adapter(sid, data: tuple[list, dict]):
        """Register the adapter."""
        key = data[1]["key"]
        if key in rooms:
            return {
                "error": {
                    "msg": f"Key {key} already exists in storage",
                    "type": "KeyError",
                }
            }
        adapter[key] = sid
        return True

    @sio.event(namespace=namespace)
    def disconnect(sid):
        """Handle client disconnection and clean up adapters."""
        # Find all adapters registered by this session ID and remove them
        adapters_to_remove = [
            key for key, adapter_sid in adapter.items() if adapter_sid == sid
        ]
        for key in adapters_to_remove:
            del adapter[key]
            # Also remove from rooms if it exists there
            if key in rooms:
                rooms.remove(key)

        if adapters_to_remove:
            print(
                f"Cleaned up {len(adapters_to_remove)} adapters for disconnected client {sid}: {adapters_to_remove}"
            )
        else:
            print(f"Client {sid} disconnected with no adapters to clean up")

    @sio.on("adapter:get", namespace=namespace)
    def adapter_get(sid, data: tuple[list, dict]):
        """Get the adapter."""
        key = data[1]["key"]
        if key not in adapter:
            return {
                "error": {
                    "msg": f"Key {key} does not exist in storage",
                    "type": "KeyError",
                }
            }
        # call the adapter and return the result
        return sio.call(
            data=data,
            event="adapter:get",
            to=adapter[key],
            namespace=namespace,
            timeout=5,
        )

    @sio.event(namespace=namespace)
    def refresh(sid, data: RefreshDataTypeDict) -> None:
        sio.emit("refresh", data, namespace=namespace, skip_sid=sid)

    @sio.event(namespace=namespace)
    def pipeline(sid, data):
        args, kwargs = data
        message = kwargs.pop("message")
        results = []
        for cmd in message:
            event = cmd[0]
            args = cmd[1][0]
            kwargs = cmd[1][1]
            if hasattr(storage, event):
                try:
                    result = {"data": getattr(storage, event)(*args, **kwargs)}
                    if isinstance(result["data"], set):
                        result["data"] = list(result["data"])
                        result["type"] = "set"
                    results.append(result)
                except TypeError as e:
                    return {
                        "error": {
                            "msg": f"Invalid arguments for {event}: {str(e)}",
                            "type": "TypeError",
                        }
                    }
                except Exception as e:
                    return {"error": {"msg": str(e), "type": type(e).__name__}}
            else:
                return {
                    "error": {
                        "msg": f"Unknown event: {event}",
                        "type": "UnknownEventError",
                    }
                }
        return {"data": results}

    # Dictionary to store chunked message data
    chunked_messages = {}

    def cleanup_expired_chunks():
        """Clean up expired chunked messages to prevent memory leaks."""
        current_time = time.time()
        expired_chunks = []
        for chunk_id, chunk_data in chunked_messages.items():
            # Clean up chunks older than 5 minutes
            if current_time - chunk_data.get("created_at", 0) > 300:
                expired_chunks.append(chunk_id)
        for chunk_id in expired_chunks:
            del chunked_messages[chunk_id]
        if expired_chunks:
            print(f"Cleaned up {len(expired_chunks)} expired chunks")

    # Schedule a single cleanup pass 60 seconds after startup. Note that
    # threading.Timer fires only once; a recurring cleanup would need to
    # reschedule itself.
    # TODO: look into this and maybe cleanup?
    cleanup_timer = threading.Timer(60.0, cleanup_expired_chunks)
    cleanup_timer.daemon = True
    cleanup_timer.start()

    def _initialize_chunk_storage(chunk_id, event, total_chunks):
        """Initialize storage for a new chunk ID."""
        chunked_messages[chunk_id] = {
            "event": event,
            "total_chunks": total_chunks,
            "received_chunks": {},
            "complete": False,
            "created_at": time.time(),
        }

    def _validate_and_store_chunk(chunk_id, chunk_index, chunk_bytes, chunk_size):
        """Validate chunk size and store the chunk data."""
        if chunk_size > 0 and len(chunk_bytes) != chunk_size:
            return {
                "error": {
                    "msg": f"Chunk {chunk_index} size mismatch",
                    "type": "ChunkSizeError",
                }
            }
        chunked_messages[chunk_id]["received_chunks"][chunk_index] = chunk_bytes
        return None

    def _reassemble_message(chunk_id, total_chunks):
        """Reassemble chunks into a complete message."""
        assembled_bytes = b""
        for i in range(total_chunks):
            if i not in chunked_messages[chunk_id]["received_chunks"]:
                return None, {
                    "error": {"msg": f"Missing chunk {i}", "type": "ChunkError"}
                }
            assembled_bytes += chunked_messages[chunk_id]["received_chunks"][i]
        return assembled_bytes, None

    def _decompress_message(assembled_bytes):
        """Decompress and parse the assembled message."""
        if len(assembled_bytes) == 0:
            raise ValueError("Empty assembled message")

        compression_flag = assembled_bytes[0:1]
        if compression_flag == b"\x01":
            if len(assembled_bytes) < 5:
                raise ValueError("Invalid compressed message format")
            original_size = int.from_bytes(assembled_bytes[1:5], "big")
            compressed_data = assembled_bytes[5:]
            json_bytes = gzip.decompress(compressed_data)
            if len(json_bytes) != original_size:
                raise ValueError("Decompressed size mismatch")
        elif compression_flag == b"\x00":
            json_bytes = assembled_bytes[1:]
        else:
            json_bytes = assembled_bytes

        complete_message = json.loads(json_bytes.decode("utf-8"))
        return complete_message[0], complete_message[1]

    def _execute_event(chunk_id, original_event, args, kwargs, sid):
        """Execute the event and store the result."""
        if original_event == "pipeline":
            try:
                result = pipeline(sid, (args, kwargs))
                chunked_messages[chunk_id]["result"] = result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
            except Exception as e:
                error_result = {"error": {"msg": str(e), "type": type(e).__name__}}
                chunked_messages[chunk_id]["result"] = error_result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
        elif hasattr(storage, original_event):
            try:
                result = {"data": getattr(storage, original_event)(*args, **kwargs)}
                if isinstance(result["data"], set):
                    result["data"] = list(result["data"])
                    result["type"] = "set"
                chunked_messages[chunk_id]["result"] = result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
            except TypeError as e:
                error_result = {
                    "error": {
                        "msg": f"Invalid arguments for {original_event}: {str(e)}",
                        "type": "TypeError",
                    }
                }
                chunked_messages[chunk_id]["result"] = error_result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
            except Exception as e:
                error_result = {"error": {"msg": str(e), "type": type(e).__name__}}
                chunked_messages[chunk_id]["result"] = error_result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
        else:
            error_result = {
                "error": {
                    "msg": f"Unknown event: {original_event}",
                    "type": "UnknownEventError",
                }
            }
            chunked_messages[chunk_id]["result"] = error_result
            chunked_messages[chunk_id]["complete"] = True
            return {"status": "complete"}

    @sio.event(namespace=namespace)
    def chunked_message(sid, data):
        """Handle chunked message fragments with optimized processing."""
        chunk_id = data["chunk_id"]
        chunk_index = data["chunk_index"]
        total_chunks = data["total_chunks"]
        event = data["event"]
        chunk_bytes = data["data"]
        chunk_size = data.get("size", 0)

        log.debug(
            f"Received chunk {chunk_index + 1}/{total_chunks} for chunk ID: {chunk_id}, event: {event}"
        )

        if chunk_id not in chunked_messages:
            _initialize_chunk_storage(chunk_id, event, total_chunks)

        try:
            error = _validate_and_store_chunk(
                chunk_id, chunk_index, chunk_bytes, chunk_size
            )
            if error:
                return error
        except Exception as e:
            return {
                "error": {
                    "msg": f"Failed to decode chunk {chunk_index}: {str(e)}",
                    "type": "ChunkDecodeError",
                }
            }

        if len(chunked_messages[chunk_id]["received_chunks"]) == total_chunks:
            assembled_bytes, error = _reassemble_message(chunk_id, total_chunks)
            if error:
                return error
            try:
                args, kwargs = _decompress_message(assembled_bytes)
                original_event = chunked_messages[chunk_id]["event"]
                return _execute_event(chunk_id, original_event, args, kwargs, sid)
            except (json.JSONDecodeError, ValueError, gzip.BadGzipFile) as e:
                error_result = {
                    "error": {
                        "msg": f"Failed to deserialize message: {str(e)}",
                        "type": "DeserializationError",
                    }
                }
                chunked_messages[chunk_id]["result"] = error_result
                chunked_messages[chunk_id]["complete"] = True
                return {"status": "complete"}
        else:
            return {
                "status": "waiting",
                "received": len(chunked_messages[chunk_id]["received_chunks"]),
                "total": total_chunks,
            }

    @sio.event(namespace=namespace)
    def compressed_message(sid, data):
        """Handle compressed single messages."""
        # TODO: might be possible to unify this with chunked_message
        event = data["event"]
        message_bytes = data["data"]

        log.debug(
            f"Received compressed message for event: {event}, size: {len(message_bytes)} bytes"
        )

        try:
            # Decompress and parse the message (same logic as in chunked messages)
            args, kwargs = _decompress_message(message_bytes)

            # Execute the event directly
            if event == "pipeline":
                return pipeline(sid, (args, kwargs))
            elif hasattr(storage, event):
                try:
                    result = {"data": getattr(storage, event)(*args, **kwargs)}
                    if isinstance(result["data"], set):
                        result["data"] = list(result["data"])
                        result["type"] = "set"
                    return result
                except TypeError as e:
                    return {
                        "error": {
                            "msg": f"Invalid arguments for {event}: {str(e)}",
                            "type": "TypeError",
                        }
                    }
                except Exception as e:
                    return {"error": {"msg": str(e), "type": type(e).__name__}}
            else:
                return {
                    "error": {
                        "msg": f"Unknown event: {event}",
                        "type": "UnknownEventError",
                    }
                }
        except (json.JSONDecodeError, ValueError, gzip.BadGzipFile) as e:
            return {
                "error": {
                    "msg": f"Failed to deserialize message: {str(e)}",
                    "type": "DeserializationError",
                }
            }

    @sio.event(namespace=namespace)
    def get_chunked_result(sid, data):
        """Get the result of a completed chunked message."""
        chunk_id = data["chunk_id"]
        log.debug(f"Retrieving result for chunk ID: {chunk_id}")

        if chunk_id not in chunked_messages:
            return {
                "error": {
                    "msg": f"Chunk ID {chunk_id} not found",
                    "type": "ChunkNotFoundError",
                }
            }

        chunk_data = chunked_messages[chunk_id]
        if not chunk_data["complete"]:
            return {
                "error": {
                    "msg": f"Chunk ID {chunk_id} not complete",
                    "type": "ChunkIncompleteError",
                }
            }

        # Get the result and clean up
        result = chunk_data["result"]
        del chunked_messages[chunk_id]
        return result

    return sio
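
To illustrate the wire format that _decompress_message expects (one flag byte, a 4-byte big-endian uncompressed size, then the gzip payload), here is a client-side sketch using a plain python-socketio client. The znsocket client library normally builds these payloads for you; the URL and key names below are illustrative assumptions:

import gzip
import json

import socketio

# Encode (args, kwargs) for Storage.hset as a compressed-message payload:
# flag 0x01 marks gzip, followed by the uncompressed size and the data.
args, kwargs = [], {"name": "users", "key": "user1", "value": "John"}
json_bytes = json.dumps([args, kwargs]).encode("utf-8")
payload = b"\x01" + len(json_bytes).to_bytes(4, "big") + gzip.compress(json_bytes)

client = socketio.Client()
client.connect("http://localhost:5000", namespaces=["/znsocket"])
result = client.call(
    "compressed_message",
    {"event": "hset", "data": payload},
    namespace="/znsocket",
)
print(result)  # expected: {"data": 1}
client.disconnect()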