Source code for jinwu.model.modelbase
"""
Model base classes inspired by XSPEC:
AdditiveModel: produces a spectrum to be summed
MultiplicativeModel: produces a transmission factor to multiply
ConvolutionModel: convolves an input spectrum
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Optional
import numpy as np
[docs]
class ModelBase(ABC):
"""Common base class for all spectral model components."""
def __init__(self, name: Optional[str] = None, params: Optional[Dict[str, float]] = None):
self.name = name or self.__class__.__name__
self.params: Dict[str, float] = params.copy() if params else {}
@property
def param_names(self) -> Iterable[str]:
return self.params.keys()
[docs]
def set_params(self, **kwargs) -> None:
for k, v in kwargs.items():
if k not in self.params:
raise KeyError(f"Unknown parameter '{k}' for model '{self.name}'")
self.params[k] = float(v)
[docs]
def __call__(self, *args, **kwargs):
return self.evaluate(*args, **kwargs)
[docs]
@abstractmethod
def evaluate(self, *args, **kwargs):
"""Return model output. Subclasses define the signature."""
raise NotImplementedError
[docs]
class AdditiveModel(ModelBase):
"""Additive component: returns a spectrum to be summed."""
[docs]
@abstractmethod
def evaluate(self, energy: np.ndarray, **kwargs) -> np.ndarray:
"""
Parameters
----------
energy : np.ndarray
Energy array (bin edges or centers).
Returns
-------
np.ndarray
Model spectrum in the same shape as energy input (or len-1 for bin edges).
"""
raise NotImplementedError
[docs]
class MultiplicativeModel(ModelBase):
"""Multiplicative component: returns a transmission factor."""
[docs]
@abstractmethod
def evaluate(self, energy: np.ndarray, **kwargs) -> np.ndarray:
"""
Parameters
----------
energy : np.ndarray
Energy array (bin edges or centers).
Returns
-------
np.ndarray
Multiplicative factor matching the spectrum shape.
"""
raise NotImplementedError
[docs]
def apply(self, energy: np.ndarray, spectrum: np.ndarray, **kwargs) -> np.ndarray:
"""Apply multiplicative factor to a given spectrum."""
factor = self.evaluate(energy, **kwargs)
return np.asarray(spectrum) * np.asarray(factor)
[docs]
class ConvolutionModel(ModelBase):
"""Convolution component: transforms an input spectrum."""
[docs]
@abstractmethod
def evaluate(self, energy: np.ndarray, spectrum: np.ndarray, **kwargs) -> np.ndarray:
"""
Parameters
----------
energy : np.ndarray
Energy array (bin edges or centers).
spectrum : np.ndarray
Input spectrum to be convolved.
Returns
-------
np.ndarray
Convolved spectrum, same shape as input spectrum.
"""
raise NotImplementedError
[docs]
def apply(self, energy: np.ndarray, spectrum: np.ndarray, **kwargs) -> np.ndarray:
"""Alias for evaluate: perform convolution on the input spectrum."""
return self.evaluate(energy, spectrum, **kwargs)