import signal
import socket
import struct
import time
from typing import Callable, Optional, Tuple
import click
import numpy as np
from loguru import logger
from .utils import load_kusp_callable, recv_exact
def _server_message(message: str, *, fg: str = "green") -> None:
    """Print a bold, colored status line tagged ``[KUSP] [SERVER]``.

    Args:
        message: Text placed after the tag.
        fg: Foreground color name handed to ``click.secho``.
    """
    banner = f"[KUSP] [SERVER] {message}"
    click.secho(banner, fg=fg, bold=True)
# "[docs]" -- navigation-link residue from the rendered source page; not part of the module.
class IPProtocol:
    """Serve KUSP models over a TCP socket.

    Each request on a connection is read, in order, as: a 4-byte native
    ``int`` giving the client's integer width in bytes (4 or 8), the atom
    count ``n_atoms`` as an integer of that width, ``n_atoms`` integers
    (species), ``3 * n_atoms`` float64 values (positions, reshaped to
    ``(n_atoms, 3)``) and ``n_atoms`` integers (contributing particles).
    The reply is the raw bytes of the handler's energy followed by the raw
    bytes of its forces.

    While the server is running, one Ctrl-C requests a model reload and two
    Ctrl-C presses within ``_sigint_window_sec`` seconds request shutdown.
    """

    _sigint_window_sec = 2.0
    # Client-declared integer width (bytes) -> (struct format code, NumPy dtype).
    _int_formats = {4: ("i", np.int32), 8: ("q", np.int64)}

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 12345,
        max_connections: int = 1,
        reuse_address: bool = True,
        recv_timeout_s: float = 15.0,
        send_timeout_s: float = 15.0,
        max_atoms: int = 1_000_000_000,
        *,
        model_file: Optional[str] = None,
        init_kwargs: Optional[dict] = None,
    ) -> None:
        """Configure protocol behavior.

        Args:
            host: Interface to bind.
            port: TCP port to bind.
            max_connections: Backlog passed to ``socket.listen``.
            reuse_address: Whether to set ``SO_REUSEADDR`` on the listener.
            recv_timeout_s: Socket timeout while receiving payloads.
            send_timeout_s: Socket timeout while sending responses.
            max_atoms: Hard upper bound on atoms accepted from clients.
            model_file: Optional path to a decorated KUSP model.
            init_kwargs: Keyword arguments forwarded to the model constructor.
        """
        self.host = host
        self.port = port
        self.max_connections = max_connections
        self.reuse_address = reuse_address
        self.recv_timeout_s = recv_timeout_s
        self.send_timeout_s = send_timeout_s
        self.max_atoms = max_atoms
        self.server_socket: Optional[socket.socket] = None
        self._running = False
        self._reload_requested = False
        self._shutdown_requested = False
        # time.monotonic() has an undefined reference point, so use -inf
        # (not 0.0) to guarantee the first Ctrl-C never counts as a "second".
        self._last_sigint_ts: float = float("-inf")
        # Handler that was active before ours, restored by stop().
        self._previous_sigint_handler = None
        self._sigint_installed = False
        self._handler: Optional[
            Callable[
                [np.ndarray, np.ndarray, Optional[np.ndarray]],
                Tuple[np.ndarray, np.ndarray],
            ]
        ] = None
        self._model_file = model_file
        # Copy so later mutation of the caller's dict does not affect reloads.
        self._init_kwargs = dict(init_kwargs or {})

    def _install_sigint_handler(self) -> None:
        """Register Ctrl-C handling for reload/shutdown semantics.

        The handler only flips flags; the serve loops poll them. The handler
        that was active before is remembered so ``stop`` can put it back.
        """

        def _on_sigint(_signum, _frame):
            now = time.monotonic()
            if now - self._last_sigint_ts <= self._sigint_window_sec:
                self._shutdown_requested = True
                _server_message(
                    f"Two Ctrl-C within {self._sigint_window_sec:.1f}s; shutting down.",
                    fg="red",
                )
            else:
                self._reload_requested = True
                _server_message(
                    "Reloading model (Ctrl-C twice quickly to exit).",
                    fg="yellow",
                )
            self._last_sigint_ts = now

        try:
            previous = signal.signal(signal.SIGINT, _on_sigint)
        except ValueError:
            # signal.signal() may only be called from the main thread.
            logger.warning(
                "Cannot install SIGINT handler outside the main thread; "
                "Ctrl-C reload/shutdown is unavailable."
            )
            return
        if not self._sigint_installed:
            # Keep the first "previous" handler; a repeated install would
            # otherwise record our own handler as the one to restore.
            self._previous_sigint_handler = previous
            self._sigint_installed = True

    def _maybe_reload(self) -> None:
        """Reload the configured model if a reload was requested.

        On failure the previously loaded handler stays active.
        """
        if not self._reload_requested:
            return
        self._reload_requested = False
        if not self._model_file:
            logger.warning(
                "Reload requested but no model file configured; ignoring."
            )
            return
        try:
            self._handler = load_kusp_callable(
                self._model_file, init_kwargs=self._init_kwargs
            )
            _server_message(f"{self._model_file} reloaded")
        except Exception as exc:
            _server_message(
                f"Failed to reload {self._model_file}; keeping previous handler.",
                fg="red",
            )
            # Record why the reload failed; the banner alone hides the cause.
            logger.exception(f"Model reload failed: {exc}")

    def start(self, on_ready: Optional[Callable[[], None]] = None) -> None:
        """Bind the listening socket and start accepting clients.

        Does nothing if the server socket already exists.

        Args:
            on_ready: Optional callback invoked once the server is bound.

        Raises:
            OSError: If the socket cannot be configured, bound or put into
                listening mode. The socket is closed before re-raising.
        """
        if self.server_socket is not None:
            return
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if self.reuse_address:
                listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.settimeout(0.5)  # poll reload/shutdown between accepts
            listener.bind((self.host, self.port))
            listener.listen(self.max_connections)
        except Exception:
            # Do not leak the descriptor when the port is busy or invalid.
            listener.close()
            raise
        self.server_socket = listener
        self._running = True
        # Clear state left by a previous run so start() after stop() serves
        # again instead of exiting at once on a stale shutdown request.
        self._reload_requested = False
        self._shutdown_requested = False
        self._last_sigint_ts = float("-inf")
        _server_message(f"TCP server listening on {self.host}:{self.port}")
        self._install_sigint_handler()
        if on_ready is not None:
            on_ready()

    def stop(self) -> None:
        """Close the server socket and restore the previous SIGINT handler.

        The SIGINT handler is only touched if this instance installed one.
        """
        self._running = False
        if self.server_socket is not None:
            try:
                self.server_socket.close()
                logger.info("KUSP TCP server stopped")
            finally:
                self.server_socket = None
        if self._sigint_installed:
            self._sigint_installed = False
            previous = self._previous_sigint_handler
            self._previous_sigint_handler = None
            if previous is None:
                # signal.signal() reports None for handlers not installed
                # from Python; fall back to Python's default Ctrl-C handler.
                previous = signal.default_int_handler
            try:
                signal.signal(signal.SIGINT, previous)
            except ValueError:
                # Only possible when called outside the main thread.
                logger.debug("Could not restore SIGINT handler from this thread.")

    def serve(self, handler: Optional[Callable] = None) -> None:
        """Run the main accept/response loop.

        Clients are served one at a time. The loop ends when shutdown is
        requested or the server is stopped; ``stop`` is always called on the
        way out of the loop.

        Args:
            handler: Optional callable overriding the configured model.

        Raises:
            RuntimeError: If `start` has not been called or no handler exists.
        """
        if self.server_socket is None:
            raise RuntimeError("IPProtocol.start must be called before serve.")
        if handler is not None:
            self._handler = handler
        elif self._model_file:
            self._handler = load_kusp_callable(
                self._model_file, init_kwargs=self._init_kwargs
            )
        else:
            raise RuntimeError(
                "No handler provided and no model_file configured."
            )
        try:
            while self._running and not self._shutdown_requested:
                self._maybe_reload()
                try:
                    client_socket, client_address = self.server_socket.accept()
                except socket.timeout:
                    # Normal: the short accept timeout lets us poll the flags.
                    continue
                except OSError:
                    if not self._running or self._shutdown_requested:
                        break
                    raise
                with client_socket:
                    logger.info(f"Client connected from {client_address}")
                    self._serve_client(client_socket, client_address)
                logger.debug(f"Connection from {client_address} closed.")
        finally:
            self.stop()

    def _serve_client(self, client_socket: socket.socket, client_address) -> None:
        """Answer requests from one connected client until it must be dropped."""
        while self._running and not self._shutdown_requested:
            self._maybe_reload()
            # Re-arm the receive timeout on every request: sending a reply
            # switches the socket to send_timeout_s.
            client_socket.settimeout(self.recv_timeout_s)
            request = self._read_request(client_socket, client_address)
            if request is None:
                break
            Z, R, M = request
            n_atoms = R.shape[0]
            t0 = time.perf_counter()
            try:
                current = self._handler
                if current is None:
                    raise RuntimeError("No active handler available")
                energy, forces = current(Z, R, M)
                logger.debug(f"Energy: {energy}, Forces: {forces}")
            except Exception as exc:
                logger.exception(f"KUSP handler raised an exception: {exc}")
                break
            elapsed_ms = (time.perf_counter() - t0) * 1000.0
            if not self._send_response(
                client_socket, client_address, energy, forces
            ):
                break
            logger.info(f"Evaluated N={n_atoms} in {elapsed_ms:.2f} ms")

    def _read_request(
        self, client_socket: socket.socket, client_address
    ) -> Optional[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Read and decode one request; return ``None`` to drop the client.

        Returns:
            ``(Z, R, M)``: species, positions of shape ``(n_atoms, 3)`` and
            contributing-particle array, or ``None`` if the client
            disconnected, timed out or sent an invalid request.
        """
        try:
            header = recv_exact(client_socket, 4)
            logger.debug(f"Received header: {header}")
        except ConnectionError:
            logger.debug(f"Client {client_address} disconnected (no header).")
            return None
        except socket.timeout:
            # A socket timeout is an OSError but not a ConnectionError.
            # NOTE(review): assumes recv_exact lets the timeout propagate;
            # handled here so an idle client is dropped instead of the
            # exception escaping serve() and stopping the whole server.
            logger.info(
                f"Timed out waiting for a request from {client_address}; closing."
            )
            return None

        try:
            int_width = struct.unpack("i", header)[0]
            logger.debug(f"Received int_width: {int_width}")
        except struct.error as exc:
            logger.warning(
                f"Malformed int-width header from {client_address}: {exc}"
            )
            return None
        if int_width not in self._int_formats:
            logger.warning(
                f"Unsupported integer width {int_width} from {client_address}"
            )
            return None
        int_fmt, int_type = self._int_formats[int_width]

        try:
            n_atoms_bytes = recv_exact(client_socket, int_width)
            n_atoms = struct.unpack(int_fmt, n_atoms_bytes)[0]
            logger.debug(f"Received n_atoms: {n_atoms}")
        except (ConnectionError, socket.timeout, struct.error) as exc:
            logger.warning(f"Failed reading n_atoms from {client_address}: {exc}")
            return None
        if n_atoms <= 0 or n_atoms > self.max_atoms:
            logger.warning(
                f"Invalid n_atoms={n_atoms} from {client_address}; closing."
            )
            return None

        ints_nbytes = int_width * n_atoms  # species and contributing arrays
        coords_nbytes = 8 * 3 * n_atoms  # three float64 per atom
        try:
            numbers_bytes = recv_exact(client_socket, ints_nbytes)
            coords_bytes = recv_exact(client_socket, coords_nbytes)
            contributing_bytes = recv_exact(client_socket, ints_nbytes)
        except (ConnectionError, socket.timeout) as exc:
            logger.warning(
                f"Client {client_address} disconnected mid-payload: {exc}"
            )
            return None

        try:
            # frombuffer yields read-only views; copy so the handler may write.
            Z = np.frombuffer(numbers_bytes, dtype=int_type).copy()
            R = (
                np.frombuffer(coords_bytes, dtype=np.float64)
                .copy()
                .reshape((n_atoms, 3))
            )
            M = np.frombuffer(contributing_bytes, dtype=int_type).copy()
        except ValueError as exc:
            logger.warning(f"Shape/dtype error from {client_address}: {exc}")
            return None
        logger.debug(
            f"Received arrays for species, positions, contributing particles:\n{Z}\n{R}\n{M}"
        )
        return Z, R, M

    def _send_response(
        self, client_socket: socket.socket, client_address, energy, forces
    ) -> bool:
        """Send energy then forces as raw bytes; return ``False`` on failure."""
        try:
            # asarray leaves ndarrays untouched and lets plain floats or
            # other array-likes returned by a handler be serialized too.
            energy_bytes = np.asarray(energy).tobytes()
            forces_bytes = np.asarray(forces).tobytes()
        except Exception as exc:
            logger.exception(f"Could not serialize handler output: {exc}")
            return False
        try:
            client_socket.settimeout(self.send_timeout_s)
            client_socket.sendall(energy_bytes)
            client_socket.sendall(forces_bytes)
        except OSError as exc:  # includes socket.timeout
            logger.warning(f"Send failed to {client_address}: {exc}")
            return False
        return True