"""High-level dataset containers.
These classes wrap lower-level OGIP data structures (e.g. LightcurveData,
PhaData, EventData) and provide a uniform, higher-level interface for
selection, slicing, merging, and background handling.
The API is intentionally minimal at this stage and can be extended as
more concrete analysis needs arise.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Literal, Union, cast
import numpy as np
from jinwu.core.base import LightcurveDataBase, PhaBase, RegionArea
from jinwu.core.ops import rebin_lightcurve
if TYPE_CHECKING:
from jinwu.core.data import LightcurveData, PhaData
# Names exported by ``from <this module> import *``: the three dataset
# containers plus the background-subtraction entry point ``netdata``.
__all__ = [
"LightcurveDataset",
"SpectrumDataset",
"JointDataset",
"netdata",
]
# ---- typing helpers -------------------------------------------------------
# Union of the two base classes accepted as "light-curve-like" input by
# ``_coerce_lightcurve`` and ``netdata``. This is an ordinary assignment, so
# the ``|`` expression is evaluated when the module is imported (it is not
# deferred by ``from __future__ import annotations``).
LightcurveInput = LightcurveDataBase | PhaBase
def _coerce_lightcurve(obj: LightcurveInput, *, arg_name: str) -> LightcurveDataBase:
    """Return *obj* as a light curve, or raise ``TypeError`` if it is not one.

    Two kinds of object are accepted:

    * real ``LightcurveDataBase`` instances, which are returned unchanged;
    * objects whose ``kind`` attribute equals ``'lc'`` and that expose every
      attribute listed in ``lc_attrs`` below. This duck-typed path keeps
      things working when the class identity has changed (e.g. after a
      module reload in a notebook).

    Parameters
    ----------
    obj
        Candidate object.
    arg_name
        Name of the caller's argument, used only in the error message.

    Raises
    ------
    TypeError
        If *obj* satisfies neither condition above.
    """
    # Attributes a duck-typed light curve must provide.
    lc_attrs = (
        "time",
        "value",
        "error",
        "dt",
        "timezero",
        "bin_exposure",
        "exposure",
        "region",
        "is_rate",
    )

    if not isinstance(obj, LightcurveDataBase):
        kind = getattr(obj, "kind", None)
        looks_like_lc = kind == "lc" and all(hasattr(obj, name) for name in lc_attrs)
        if not looks_like_lc:
            raise TypeError(
                f"{arg_name} must be a LightcurveData (kind='lc'), got {type(obj).__name__} from {type(obj).__module__} with kind={kind!r}."
            )
        # Only a static-typing cast; the object itself is returned untouched.
        return cast(LightcurveDataBase, obj)

    # Normal case: already an instance of the expected base class.
    return obj
[docs]
@dataclass(slots=True)
class LightcurveDataset:
"""光变曲线容器类(支持多条曲线统一绘图)
参数
----
data : List[LightcurveData] | LightcurveData
单个或多个光变曲线
labels : List[str] | str | None, optional
每条曲线的标签
示例
----
>>> ds = LightcurveDataset(data=[lc1, lc2, lc3], labels=["Src", "Bkg", "Net"])
>>> ds.plot(ykind='rate', multiband=True)
>>> ds = lc1 + lc2 + lc3 # 链式创建
>>> ds = ds + lc4 # 添加新曲线
"""
data: List[LightcurveDataBase]
labels: Optional[List[str]] = None
def __post_init__(self):
"""确保 data 是列表"""
if not isinstance(self.data, list):
self.data = [self.data]
if self.labels is not None and not isinstance(self.labels, list):
self.labels = [self.labels]
if self.labels is not None and len(self.labels) != len(self.data):
raise ValueError(f"labels length ({len(self.labels)}) != data length ({len(self.data)})")
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, index: int) -> LightcurveDataBase:
return self.data[index]
def __add__(self, other: Union[LightcurveDataBase, 'LightcurveDataset']) -> 'LightcurveDataset':
"""添加新的光变曲线到容器
示例
----
>>> ds = ds + new_lc
>>> ds = ds1 + ds2 # 合并两个 dataset
"""
if isinstance(other, LightcurveDataBase):
return LightcurveDataset(
data=self.data + [other],
labels=self.labels
)
elif isinstance(other, LightcurveDataset):
new_labels = None
if self.labels is not None and other.labels is not None:
new_labels = self.labels + other.labels
return LightcurveDataset(
data=self.data + other.data,
labels=new_labels
)
else:
return NotImplemented
[docs]
def plot(self, *, ax=None, ykind: Literal['auto', 'rate', 'counts', 'flux'] = 'auto',
multiband: Union[bool, str] = "auto", colors=None, title=None,
grid: bool = True, **kwargs):
"""绘制光变曲线
参数
----
multiband : bool | 'auto', default='auto'
True: 多子图模式;False: 叠加模式;'auto': 自动选择
colors : list[str], optional
每条曲线的颜色
其他参数传递给 plot_lightcurve
"""
from jinwu.core.plot import plot_lightcurve
import matplotlib.pyplot as plt
# 单条曲线:直接绘制
if len(self.data) == 1:
label = self.labels[0] if self.labels else None
color = colors[0] if colors else None
return plot_lightcurve(
self.data[0], ax=ax, ykind=ykind, multiband=multiband,
color=color, label=label, title=title, grid=grid, **kwargs
)
# 多条曲线
if multiband == True or (multiband == "auto" and len(self.data) > 3):
# 多子图模式
fig, axes = plt.subplots(len(self.data), 1, figsize=(10, 3*len(self.data)), sharex=True)
if not isinstance(axes, np.ndarray):
axes = [axes]
for i, lc in enumerate(self.data):
label = self.labels[i] if self.labels else None
color = colors[i] if colors else None
plot_lightcurve(lc, ax=axes[i], ykind=ykind, color=color, label=label, grid=grid, **kwargs)
if title:
fig.suptitle(title)
return axes
else:
# 叠加模式
if ax is None:
fig, ax = plt.subplots(figsize=(10, 6))
for i, lc in enumerate(self.data):
label = self.labels[i] if self.labels else None
color = colors[i] if colors else None
plot_lightcurve(lc, ax=ax, ykind=ykind, color=color, label=label, grid=grid, **kwargs)
if title:
ax.set_title(title)
if self.labels:
ax.legend()
return ax
[docs]
@dataclass(slots=True)
class SpectrumDataset:
"""High-level spectral dataset container.
Parameters
----------
data : PhaData
The underlying OGIP spectral data (SPECTRUM + optional EBOUNDS).
label : str, optional
A human-readable label for this spectrum (e.g. instrument, epoch).
background : SpectrumDataset | None, optional
Optional associated background spectrum.
"""
data: PhaData
label: Optional[str] = None
background: Optional["SpectrumDataset"] = None
[docs]
@dataclass(slots=True)
class JointDataset:
"""A simple container for multiple datasets.
This can hold any combination of light curves and spectra that
conceptually belong to the same astrophysical source/event.
"""
lightcurves: List[LightcurveDataset]
spectra: List[SpectrumDataset]
[docs]
def add_lightcurve(self, lc: LightcurveDataset) -> None:
self.lightcurves.append(lc)
[docs]
def add_spectrum(self, spec: SpectrumDataset) -> None:
self.spectra.append(spec)
[docs]
def netdata(
    source: LightcurveInput,
    background: Optional[LightcurveInput] = None,
    *,
    ratio: Optional[float] = None,
    use_exposure_weighted_ratio: bool = True,
    offset: float = 0.0,
) -> LightcurveDataBase:
    """Compute the net (background-subtracted) light curve.

    Parameters
    ----------
    source : LightcurveData | OgipData
        Source light curve.
    background : LightcurveData | OgipData | None
        Background light curve. When ``None`` the source is returned as is.
    ratio : float, optional
        Factor the background counts are multiplied by before subtraction.
        When ``None`` it is derived from the region areas (and exposures).
    use_exposure_weighted_ratio : bool, default=True
        Only used when ``ratio`` is ``None``. If true the ratio is
        ``(src_area * src_exposure) / (bkg_area * bkg_exposure)``; otherwise
        it is ``src_area / bkg_area``.
    offset : float, default=0.0
        Extra number of counts subtracted from every bin (count space).

    Returns
    -------
    LightcurveData
        A new object of the same class as the source, carrying the source's
        time axis and metadata with net values. ``value``/``error`` are in
        the source's unit (rate or counts); ``counts``, ``rate`` and their
        errors are both filled in.

    Raises
    ------
    TypeError
        If ``source`` or ``background`` is not a light curve.
    ValueError
        If the ratio cannot be inferred, a time array is missing, the source
        bin size is unusable, a ``value`` array is missing, or the source
        and (aligned) background arrays have incompatible shapes.

    Examples
    --------
    >>> net = netdata(src, bkg)              # ratio inferred
    >>> net = netdata(src, bkg, ratio=1.5)   # explicit ratio

    Notes
    -----
    Steps:

    1. Determine ``ratio`` (explicit or inferred).
    2. If the time axes differ, rebin the background to the source's
       (median) bin size.
    3. Convert both curves to counts using ``bin_exposure`` (falling back
       to ``dt``, then to 1).
    4. ``net = src - ratio * bkg - offset``.
    5. ``err**2 = src_err**2 + (ratio * bkg_err)**2``.
    6. Convert back: rates are NaN wherever the exposure is not finite and
       positive.
    7. Bins whose source ``bin_exposure`` is exactly zero are NaN in counts
       and rate alike, whatever the source's unit.

    The original notes state that ``src - bkg`` is routed through this
    function; that wiring is not visible here.
    """
    src_lc = _coerce_lightcurve(source, arg_name="source")
    if background is None:
        return src_lc
    bkg_lc = _coerce_lightcurve(background, arg_name="background")

    # ========== 1. Determine ratio ==========
    if ratio is None:
        def _area_of(lc):
            # A missing region or a falsy area (None/0) counts as "no area".
            region = lc.region
            return region.area if (region and region.area) else None

        src_area = _area_of(src_lc)
        bkg_area = _area_of(bkg_lc)
        area_missing = src_area is None or bkg_area is None or bkg_area == 0

        if use_exposure_weighted_ratio:
            if area_missing:
                raise ValueError(
                    "Cannot infer ratio: region area missing. "
                    "Provide ratio explicitly or ensure both have valid region.area"
                )
            src_exp = src_lc.exposure
            bkg_exp = bkg_lc.exposure
            if src_exp is None or bkg_exp is None or bkg_exp == 0:
                raise ValueError(
                    "Cannot infer exposure-weighted ratio: exposure missing or zero. "
                    "Provide ratio explicitly or ensure both have valid exposure"
                )
            ratio = (src_area * src_exp) / (bkg_area * bkg_exp)
        else:
            if area_missing:
                raise ValueError("Cannot infer ratio from areas. Provide explicitly.")
            ratio = src_area / bkg_area

    # ========== 2. Align time axes ==========
    if src_lc.time is None or bkg_lc.time is None:
        raise ValueError("source/background time array is None")
    src_time = np.asarray(src_lc.time, dtype=float)
    bkg_time = np.asarray(bkg_lc.time, dtype=float)
    bkg_aligned = bkg_lc
    if not (src_time.shape == bkg_time.shape and np.allclose(src_time, bkg_time)):
        if src_lc.dt is None:
            raise ValueError("Cannot align: source.dt is None")
        binsize_arr = np.asarray(src_lc.dt, dtype=float)
        # Per-bin dt arrays are reduced to their median; a scalar is used as is.
        binsize = float(np.median(binsize_arr)) if binsize_arr.ndim > 0 else float(binsize_arr)
        if not np.isfinite(binsize) or binsize <= 0:
            raise ValueError(f"Cannot align: invalid source binsize={binsize}")
        # NOTE(review): a timezero of 0 is falsy and is therefore passed as
        # None - confirm that this is the intended alignment reference.
        bkg_aligned = rebin_lightcurve(
            bkg_lc, binsize=binsize, method='auto',
            align_ref=src_lc.timezero if src_lc.timezero else None
        )

    # ========== 3. Convert to count space ==========
    def _to_counts(lc):
        """Return ``(counts, counts_error, exposure)`` arrays for ``lc``."""
        if lc.value is None:
            raise ValueError("LightcurveData.value is None")
        exp = lc.bin_exposure if lc.bin_exposure is not None else (lc.dt if lc.dt is not None else 1.0)
        exp_arr = np.asarray(exp, dtype=float)
        val = np.asarray(lc.value, dtype=float)
        if lc.is_rate:
            counts = val * exp_arr
            if lc.error is not None:
                err_counts = np.asarray(lc.error, dtype=float) * exp_arr
            else:
                # No errors supplied: fall back to sqrt(counts), clipped at 0.
                err_counts = np.sqrt(np.maximum(counts, 0.0))
        else:
            counts = val
            if lc.error is not None:
                err_counts = np.asarray(lc.error, dtype=float)
            else:
                err_counts = np.sqrt(np.maximum(counts, 0.0))
        return counts, err_counts, exp_arr

    src_counts, src_err, src_exp_arr = _to_counts(src_lc)
    bkg_counts, bkg_err, _ = _to_counts(bkg_aligned)

    # ========== 4. Subtract in count space ==========
    scale = float(ratio)
    shift = float(offset)
    try:
        net_counts = src_counts - scale * bkg_counts - shift
        net_var = (src_err ** 2) + (scale ** 2) * (bkg_err ** 2)
    except ValueError as err:
        # NumPy's broadcast error does not say which inputs disagree.
        raise ValueError(
            "Cannot subtract background: incompatible array shapes "
            f"(source counts {np.shape(src_counts)}, source errors {np.shape(src_err)}, "
            f"background counts {np.shape(bkg_counts)}, background errors {np.shape(bkg_err)})"
        ) from err
    net_counts_err = np.sqrt(net_var)

    # ========== 5. Build counts and rate products ==========
    # Broadcast so that a scalar exposure (e.g. a single dt) yields a
    # per-bin mask instead of a 0-d boolean index.
    exp_per_bin = np.broadcast_to(src_exp_arr, np.shape(net_counts))
    valid_exp = np.isfinite(exp_per_bin) & (exp_per_bin > 0)

    out_rate = np.full_like(net_counts, np.nan, dtype=float)
    out_rate_err = np.full_like(net_counts, np.nan, dtype=float)
    out_rate[valid_exp] = net_counts[valid_exp] / exp_per_bin[valid_exp]
    out_rate_err[valid_exp] = net_counts_err[valid_exp] / exp_per_bin[valid_exp]

    # Copies, so masking below never touches the intermediate arrays.
    out_counts = np.array(net_counts, dtype=float)
    out_counts_err = np.array(net_counts_err, dtype=float)

    # ========== 6. Zero-exposure bins -> NaN ==========
    # Applied to counts as well as rates in both unit modes, so a
    # zero-exposure bin never carries a spurious "-ratio * bkg" count.
    if src_lc.bin_exposure is not None:
        zero_mask = (np.asarray(src_lc.bin_exposure, dtype=float) == 0.0)
        if np.any(zero_mask):
            out_counts[zero_mask] = np.nan
            out_counts_err[zero_mask] = np.nan
            out_rate[zero_mask] = np.nan
            out_rate_err[zero_mask] = np.nan

    # value/error follow the source's own unit.
    if src_lc.is_rate:
        out_value, out_err = out_rate, out_rate_err
    else:
        out_value, out_err = out_counts, out_counts_err

    # ========== 7. Assemble result ==========
    bin_lo = src_lc.bin_lo
    bin_hi = src_lc.bin_hi
    bin_width = src_lc.bin_width
    binning = src_lc.binning
    if (bin_lo is None or bin_hi is None or bin_width is None) and src_lc.time is not None:
        # Best effort: if the geometry cannot be resolved, the source's
        # (possibly missing) bin fields are passed through unchanged.
        try:
            lo, hi, bw = src_lc._resolve_bin_geometry()
            bin_lo = lo
            bin_hi = hi
            bin_width = bw
            if bw.size > 0 and np.allclose(bw, float(np.median(bw)), rtol=1e-8, atol=1e-12):
                binning = 'uniform'
            else:
                binning = 'variable'
        except Exception:
            pass

    net_cls = type(src_lc)
    net = net_cls(
        time=src_lc.time, value=out_value, error=out_err,
        dt=src_lc.dt, exposure=src_lc.exposure, is_rate=src_lc.is_rate,
        bin_exposure=src_lc.bin_exposure,
        timezero=src_lc.timezero, timezero_obj=src_lc.timezero_obj,
        bin_lo=bin_lo, bin_hi=bin_hi, bin_width=bin_width, binning=binning,
        counts=out_counts, rate=out_rate,
        counts_err=out_counts_err, rate_err=out_rate_err,
        err_dist=src_lc.err_dist,
        tstart=src_lc.tstart, tseg=src_lc.tseg,
        gti_start=src_lc.gti_start, gti_stop=src_lc.gti_stop,
        quality=src_lc.quality, fracexp=src_lc.fracexp,
        backscal=src_lc.backscal, areascal=src_lc.areascal,
        telescop=src_lc.telescop, timesys=src_lc.timesys, mjdref=src_lc.mjdref,
        path=src_lc.path, header=src_lc.header, meta=src_lc.meta,
        headers_dump=src_lc.headers_dump, region=src_lc.region,
        columns=src_lc.columns, ratio=ratio,
    )
    return net