Source code for kusp.kusp

import inspect
import typing
from typing import Tuple, Union

import numpy as np
from loguru import logger
import sys


def _ensure_default_logging(default_level: str = "WARNING") -> None:
    """
    Quieten loguru's stock handler unless logging is already configured.

    loguru starts with a single DEBUG-level handler whose id is 0. Left in
    place, debug output (notably when driven from the C++ interface) becomes
    too noisy. If that stock handler is the only one installed, it is removed
    and replaced by a stderr handler at ``default_level``. Any other handler
    configuration is left untouched, so calling this repeatedly is harmless:
    the replacement handler gets a non-zero id and is not touched again.

    Args:
        default_level: Minimum level for the replacement stderr handler.

    Only internal-structure-related exceptions (``logger._core`` is private
    loguru API) are caught; interrupts and other errors propagate.
    """
    try:
        # Private loguru API: a dict mapping handler id -> handler.
        handlers = logger._core.handlers
        if set(handlers) != {0}:
            # User-configured (or already quietened) logging: leave it alone.
            return
        logger.remove()
    except (KeyError, AttributeError, TypeError):
        return

    # Interpreters embedded in a host process (e.g. started from C++) may have
    # no stderr at all; in that case stay silent rather than fail.
    if sys.stderr is not None:
        logger.add(sys.stderr, level=default_level)


def _returns_ndarray_pair(hint) -> bool:
    """Return True if *hint* spells ``Tuple[np.ndarray, np.ndarray]``.

    ``typing.get_origin`` maps both ``typing.Tuple[...]`` and builtin
    ``tuple[...]`` to ``tuple``, so either spelling is accepted.
    """
    if hint is None or typing.get_origin(hint) not in (tuple, Tuple):
        return False
    args = typing.get_args(hint)
    return len(args) == 2 and args[0] is np.ndarray and args[1] is np.ndarray


def _model_signature_problems(functor) -> typing.List[str]:
    """Return the KUSP signature problems found on *functor* (empty if none).

    Problems are returned in checking order: parameter count, then
    ``*args``/``**kwargs`` usage, then the return type hint.
    """
    is_class = inspect.isclass(functor)
    # Pick target: the function itself, or the class's __call__ method.
    target = functor.__call__ if is_class else functor
    logger.debug(f"Got the functor: {target}")

    problems = []
    try:
        params = list(inspect.signature(target).parameters.values())
    except (TypeError, ValueError) as err:
        # Builtins / C callables may not expose a signature.
        problems.append(f"Could not inspect model signature: {err}")
        params = None

    if params is not None:
        # For classes, drop ``self`` from the unbound ``__call__`` signature.
        user_params = params[1:] if is_class else params
        if len(user_params) < 3:
            problems.append(
                "Model must accept three parameters: (species, positions, contributing)."
            )
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        if any(p.kind in variadic for p in user_params):
            problems.append(
                "Do not use *args/**kwargs; expect exactly (species, positions, contributing)."
            )

    try:
        hints = typing.get_type_hints(target)
    except Exception:  # best effort: unresolved forward refs, odd targets, ...
        hints = {}
    if not _returns_ndarray_pair(hints.get("return")):
        problems.append(
            "Missing/incorrect return type hint; expected Tuple[np.ndarray, np.ndarray]. "
            "Either provide concrete hints or pass strict_arg_check=False argument in decorator."
        )
    return problems


def kusp_model(
    influence_distance: Union[float, np.float64, np.ndarray],
    species: Tuple[str, ...],
    strict_arg_check: bool = True,
    **metadata,
):
    """Mark a callable as a KUSP model entry point.

    The decorated object (a function, or a class whose ``__call__`` is the
    model) is checked for a ``(species, positions, contributing)``-style
    signature without ``*args``/``**kwargs`` and a
    ``Tuple[np.ndarray, np.ndarray]`` return hint, then tagged with
    ``__kusp_model__``, ``__kusp_metadata__``, ``__kusp_influence_distance__``
    and ``__kusp_species__`` attributes.

    Args:
        influence_distance: Cutoff distance advertised to KIM-API.
        species: Tuple of species symbols in the order expected by the model.
        strict_arg_check: If True, signature problems raise ``TypeError``;
            if False, they are only logged as warnings.
        **metadata: Extra key/value pairs, stored as a dict in the decorated
            object's ``__kusp_metadata__`` attribute.

    TODO: Add support for units. It is easy, just extra decorator arguments.

    Returns:
        Decorating function that adds bookkeeping attributes and returns the
        same object it was given.

    Raises:
        TypeError: (from the decorator) when ``strict_arg_check`` is True and
            the decorated object fails a signature check.
    """
    _ensure_default_logging()

    def _decorator(functor):
        logger.debug(f"Received influence_distance: {influence_distance}, for object {functor}")
        if strict_arg_check:
            logger.debug("Using strict arg check; pass strict_arg_check=False to only warn.")

        problems = _model_signature_problems(functor)
        if problems and strict_arg_check:
            # Strict mode: fail on the first problem, as documented.
            logger.error(problems[0])
            raise TypeError(problems[0])
        for problem in problems:
            logger.warning(problem)

        # Annotate functor for KUSP
        functor.__kusp_model__ = True
        functor.__kusp_metadata__ = metadata
        functor.__kusp_influence_distance__ = influence_distance
        functor.__kusp_species__ = species
        logger.debug(f"All done, returning the object: {functor.__dict__}")
        return functor

    return _decorator