from __future__ import annotations
import importlib
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union, cast
import numpy as np
from .base import ArfBase, RmfBase, PhaBase, LightcurveDataBase, EventDataBase, FitsHeaderDump
from .ogip import OgipFitsBase, ValidationReport
def _normalize_gti(gti: Any) -> np.ndarray:
arr = np.asarray(gti, dtype=float)
if arr.size == 0:
return np.asarray([], dtype=float).reshape(0, 2)
arr = arr.reshape(-1, 2)
good = np.isfinite(arr[:, 0]) & np.isfinite(arr[:, 1]) & (arr[:, 1] > arr[:, 0])
arr = arr[good]
if arr.size == 0:
return np.asarray([], dtype=float).reshape(0, 2)
order = np.argsort(arr[:, 0])
return arr[order]
def _merge_gti(gti: np.ndarray) -> np.ndarray:
arr = _normalize_gti(gti)
if arr.shape[0] <= 1:
return arr
merged: list[list[float]] = [[float(arr[0, 0]), float(arr[0, 1])]]
for s, e in arr[1:]:
ls, le = merged[-1]
if float(s) <= le:
if float(e) > le:
merged[-1][1] = float(e)
else:
merged.append([float(s), float(e)])
return np.asarray(merged, dtype=float)
def _create_gti_mask_fallback(time: Any, gti: Any, dt: Optional[float] = None) -> np.ndarray:
t = np.asarray(time, dtype=float).reshape(-1)
mask = np.zeros(t.shape, dtype=bool)
g = _normalize_gti(gti)
if t.size == 0 or g.size == 0:
return mask
dt_val = None
if dt is not None:
try:
dt_f = float(dt)
if np.isfinite(dt_f) and dt_f > 0:
dt_val = dt_f
except Exception:
dt_val = None
if dt_val is None:
for s, e in g:
mask |= (t >= float(s)) & (t < float(e))
else:
t1 = t + dt_val
for s, e in g:
# bin 与 GTI 有重叠即保留
mask |= (t1 > float(s)) & (t < float(e))
return mask
def _create_gti_mask(time: Any, gti: Any, dt: Optional[float] = None) -> np.ndarray:
try:
mod = importlib.import_module('stingray.gti')
sr_create_gti_mask = getattr(mod, 'create_gti_mask')
return np.asarray(sr_create_gti_mask(time, gti, dt=dt), dtype=bool)
except Exception:
return _create_gti_mask_fallback(time, gti, dt=dt)
def _join_gtis_fallback(gti_a: Any, gti_b: Any) -> np.ndarray:
a = _normalize_gti(gti_a)
b = _normalize_gti(gti_b)
if a.size == 0 and b.size == 0:
return np.asarray([], dtype=float).reshape(0, 2)
if a.size == 0:
return b
if b.size == 0:
return a
return _merge_gti(np.vstack([a, b]))
def _join_gtis(gti_a: Any, gti_b: Any) -> np.ndarray:
try:
mod = importlib.import_module('stingray.gti')
sr_join_gtis = getattr(mod, 'join_gtis')
joined = np.asarray(sr_join_gtis(np.asarray(gti_a), np.asarray(gti_b)), dtype=float)
if joined.size == 0:
return np.asarray([], dtype=float).reshape(0, 2)
return joined.reshape(-1, 2)
except Exception:
return _join_gtis_fallback(gti_a, gti_b)
def _validate_time_series_like(obj: Any) -> ValidationReport:
rpt = OgipFitsBase.validate(obj)
for k in obj.REQUIRED_KEYS:
if not obj._has_key_ci(k):
lvl = 'ERROR' if k in getattr(obj, 'CRITICAL_KEYS', ()) else 'WARN'
rpt.add(lvl, 'MISSING_KEY', f"Required key '{k}' not found (OGIP-93-003).")
cols = getattr(obj, 'columns', ()) or ()
colset = {c.upper() for c in cols}
for group in obj.REQUIRED_COLUMNS_ANY:
if not any(c.upper() in colset for c in group):
rpt.add('ERROR', 'MISSING_COLUMN', f"Missing required column group: {group}")
if not (obj._has_key_ci('MJDREF') or obj._has_any_key_ci(['MJDREFI', 'MJDREFF'])):
rpt.add('WARN', 'MISSING_MJDREF', 'MJDREF / (MJDREFI+MJDREFF) not found in header; absolute times may be ambiguous.')
try:
timeunit = obj.get_keyword_ci('TIMEUNIT')
except Exception:
timeunit = (obj.header or {}).get('TIMEUNIT')
if timeunit is not None and str(timeunit).upper() not in ('S', 'SEC', 'SECOND', 'SECONDS'):
rpt.add('INFO', 'UNUSUAL_TIMEUNIT', f"TIMEUNIT='{timeunit}'")
return rpt
[docs]
class ArfData(ArfBase):
"""Concrete ARF data class with local behavior implementation."""
[docs]
def validate(self) -> ValidationReport: # type: ignore[override]
rpt = OgipFitsBase.validate(self)
hdr = self.header
for group in self.REQUIRED_KEYS_ANY:
if not any(g in hdr for g in group):
rpt.add('WARN', 'MISSING_KEY', f"Missing one of required keys {group} (CAL/GEN/92-002).")
colset = {c.upper() for c in self.columns}
for c in ["ENERG_LO", "ENERG_HI", "SPECRESP"]:
if c not in colset:
rpt.add('ERROR', 'MISSING_COLUMN', f"ARF missing column {c}")
rpt.ok = len(rpt.errors()) == 0
return rpt
[docs]
def plot(
self,
ax: Optional[Any] = None,
*,
energy_unit: str = 'keV',
yscale: str = 'linear',
title: Optional[str] = None,
**kwargs,
):
import matplotlib.pyplot as _plt
mid_e = 0.5 * (np.asarray(self.energ_lo) + np.asarray(self.energ_hi))
ax = ax or _plt.gca()
assert ax is not None
kwargs.setdefault('lw', 1.0)
ax.plot(mid_e, self.specresp, **kwargs)
ax.set_xlabel(f"Energy ({energy_unit})")
ax.set_ylabel("Effective Area (cm$^2$)")
ax.set_yscale(yscale)
if title is None:
title = Path(str(getattr(self, 'path', ''))).name or 'ARF'
ax.set_title(title)
ax.grid(alpha=0.3, ls='--')
return ax
[docs]
def rebin(self, factor: int) -> 'ArfData':
if factor <= 0:
raise ValueError('factor must be > 0')
from ..ftools.ftrbnrmf import rebin_arf
elo = np.asarray(self.energ_lo, dtype=float)
ehi = np.asarray(self.energ_hi, dtype=float)
area = np.asarray(self.specresp, dtype=float)
n = elo.size
groups = [(i, min(i + factor, n)) for i in range(0, n, factor)]
new_elo = np.array([elo[s] for s, _ in groups], dtype=float)
new_ehi = np.array([ehi[e - 1] for _, e in groups], dtype=float)
new_area = rebin_arf(elo, ehi, area, new_elo, new_ehi)
return ArfData(
path=self.path,
energ_lo=new_elo,
energ_hi=new_ehi,
specresp=new_area,
columns=self.columns,
header=self.header,
meta=self.meta,
headers_dump=self.headers_dump,
)
[docs]
class RmfData(RmfBase):
"""Concrete RMF data class with local behavior implementation."""
[docs]
def rebuild_dense(self) -> np.ndarray:
if self.channel is None or self.e_min is None or self.e_max is None:
if self.matrix.ndim == 2:
return self.matrix
raise ValueError("Cannot rebuild dense RMF without EBOUNDS and sparse definition")
n_e = self.energ_lo.size
n_c = int(self.channel.size)
out = np.zeros((n_e, n_c), dtype=float)
if (self.n_grp is not None) and (self.f_chan is not None) and (self.n_chan is not None):
for i in range(n_e):
ng = int(self.n_grp[i]) if np.ndim(self.n_grp) > 0 else int(self.n_grp)
row_vals = self.matrix[i]
if isinstance(row_vals, np.ndarray):
cursor = 0
fcs = np.atleast_1d(self.f_chan[i]) if np.ndim(self.f_chan) > 1 else np.atleast_1d(self.f_chan)
ncs = np.atleast_1d(self.n_chan[i]) if np.ndim(self.n_chan) > 1 else np.atleast_1d(self.n_chan)
for g in range(ng):
start = int(fcs[g])
width = int(ncs[g])
out[i, start:start + width] = row_vals[cursor:cursor + width]
cursor += width
else:
if np.ndim(self.matrix) == 2 and self.matrix.shape[1] == n_c:
out[i, :] = self.matrix[i, :]
else:
raise ValueError("Unsupported RMF MATRIX layout")
return out
if self.matrix.ndim == 2 and self.matrix.shape[1] == n_c:
return self.matrix
raise ValueError("RMF does not contain recognizable sparse columns and MATRIX is not dense")
[docs]
def validate(self) -> ValidationReport: # type: ignore[override]
rpt = OgipFitsBase.validate(self)
hdr = self.header
for group in self.REQUIRED_KEYS_ANY:
if not any(g in hdr for g in group):
rpt.add('WARN', 'MISSING_KEY', f"Missing one of required keys {group} (CAL/GEN/92-002).")
colset = {c.upper() for c in self.columns}
for c in ["ENERG_LO", "ENERG_HI", "MATRIX"]:
if c not in colset:
rpt.add('ERROR', 'MISSING_COLUMN', f"RMF missing column {c}")
if self.channel is not None:
for c in ["CHANNEL", "E_MIN", "E_MAX"]:
if c not in colset:
rpt.add('WARN', 'MISSING_COLUMN', f"EBOUNDS expected column {c}")
rpt.ok = len(rpt.errors()) == 0
return rpt
@property
def dense_matrix(self) -> np.ndarray:
return self.rebuild_dense()
[docs]
def plot(
self,
ax: Optional[Any] = None,
*,
kind: str = 'matrix',
row: int = 0,
yscale: str = 'linear',
cmap: str = 'viridis',
title: Optional[str] = None,
**kwargs,
):
import matplotlib.pyplot as _plt
ax = ax or _plt.gca()
assert ax is not None
if kind == 'matrix':
dm = self.dense_matrix
im = ax.imshow(dm, aspect='auto', origin='lower', cmap=cmap, **kwargs)
ax.set_xlabel('Channel')
ax.set_ylabel('Energy bin')
if title is None:
title = Path(str(getattr(self, 'path', ''))).name or 'RMF'
ax.set_title(title)
ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label='Prob.')
else:
dm = self.dense_matrix
row = max(0, min(dm.shape[0] - 1, int(row)))
kwargs.setdefault('lw', 1.0)
ax.plot(dm[row], **kwargs)
ax.set_xlabel('Channel')
ax.set_ylabel('Probability')
ax.set_yscale(yscale)
if title is None:
title = f"RMF row {row}"
ax.set_title(title)
ax.grid(alpha=0.3, ls='--')
return ax
[docs]
def rebin(self, factor: int) -> 'RmfData':
if factor <= 0:
raise ValueError('factor must be > 0')
dense = self.rebuild_dense()
n_e, _ = dense.shape
groups = [(i, min(i + factor, n_e)) for i in range(0, n_e, factor)]
new_rows = []
new_elo = []
new_ehi = []
for s, e in groups:
new_rows.append(np.sum(dense[s:e, :], axis=0))
new_elo.append(float(self.energ_lo[s]))
new_ehi.append(float(self.energ_hi[e - 1]))
new_matrix = np.vstack(new_rows)
return RmfData(
path=self.path,
energ_lo=np.asarray(new_elo, dtype=float),
energ_hi=np.asarray(new_ehi, dtype=float),
n_grp=None,
f_chan=None,
n_chan=None,
matrix=new_matrix,
channel=self.channel,
e_min=self.e_min,
e_max=self.e_max,
columns=self.columns,
header=self.header,
meta=self.meta,
headers_dump=self.headers_dump,
)
[docs]
class PhaData(PhaBase):
"""Concrete PHA data class with local behavior implementation."""
[docs]
def validate(self) -> ValidationReport: # type: ignore[override]
rpt = OgipFitsBase.validate(self)
for k in self.REQUIRED_KEYS:
if not self._has_key_ci(k):
rpt.add('WARN', 'MISSING_KEY', f"Required key '{k}' not found (OGIP-92-007).")
colset = {c.upper() for c in self.columns}
for c in ['CHANNEL']:
if c not in colset:
rpt.add('ERROR', 'MISSING_COLUMN', f"PHA missing column {c}")
if not any(c in colset for c in ('COUNTS', 'RATE')):
rpt.add('ERROR', 'MISSING_COLUMN', "PHA missing one of COUNTS/RATE")
exp_val = self.get_keyword_ci('EXPOSURE', self.get_keyword_ci('EXPTIME', None))
if exp_val is not None:
try:
if float(exp_val) <= 0:
rpt.add('WARN', 'BAD_EXPOSURE', f"Non-positive exposure value: {exp_val}")
except Exception:
rpt.add('WARN', 'BAD_EXPOSURE', f"Exposure not numeric: {exp_val}")
rpt.ok = len(rpt.errors()) == 0
return rpt
@property
def count_rate(self) -> Optional[np.ndarray]:
if self.rate is not None:
return np.asarray(self.rate, dtype=float)
if self.exposure and self.exposure > 0:
return self.counts / self.exposure
return None
[docs]
def plot(self, **kwargs):
from .plot import plot_spectrum
return plot_spectrum(self, **kwargs)
[docs]
def slice(
self,
*,
emin: Optional[float] = None,
emax: Optional[float] = None,
ch_lo: Optional[int] = None,
ch_hi: Optional[int] = None,
) -> 'PhaData':
from .ops import slice_pha
return cast(PhaData, slice_pha(cast(Any, self), emin=emin, emax=emax, ch_lo=ch_lo, ch_hi=ch_hi))
[docs]
def rebin(self, *, factor: Optional[int] = None, min_counts: Optional[float] = None) -> 'PhaData':
from .ops import rebin_pha
return cast(PhaData, rebin_pha(cast(Any, self), factor=factor, min_counts=min_counts))
[docs]
def grppha(
self,
*,
min_counts: Optional[float] = None,
groupfile: Optional[str] = None,
rebin: bool = False,
outfile: Optional[str] = None,
overwrite: bool = False,
) -> 'PhaData':
from ..ftools.grppha import grppha as _grppha
return cast(PhaData, _grppha(
cast(Any, self),
min_counts=min_counts,
groupfile=groupfile,
rebin=rebin,
outfile=outfile,
overwrite=overwrite,
))
[docs]
def group_by_min_counts(self, min_counts: float) -> np.ndarray:
from ..ftools.grppha import compute_grouping_by_min_counts
return compute_grouping_by_min_counts(self.counts, min_counts)
[docs]
class LightcurveData(LightcurveDataBase):
"""Concrete lightcurve data class with local behavior implementation."""
[docs]
def validate(self) -> ValidationReport: # type: ignore[override]
rpt = _validate_time_series_like(self)
colset = {c.upper() for c in self.columns}
if "TIME" not in colset:
rpt.add('ERROR', 'MISSING_COLUMN', "Lightcurve missing TIME column")
if not any(x in colset for x in ["RATE", "COUNTS"]):
rpt.add('ERROR', 'MISSING_COLUMN', "Lightcurve missing RATE/COUNTS column")
if self.time is None or len(self.time) == 0:
rpt.add('ERROR', 'MISSING_TIME', "time array is empty or None")
if self.time_raw is not None and self.time is not None and len(self.time_raw) != len(self.time):
rpt.add('ERROR', 'LENGTH_MISMATCH', f"time_raw ({len(self.time_raw)}) and time ({len(self.time)}) length mismatch")
if self.time_rel is not None and self.time is not None and len(self.time_rel) != len(self.time):
rpt.add('ERROR', 'LENGTH_MISMATCH', f"time_rel ({len(self.time_rel)}) and time ({len(self.time)}) length mismatch")
if self.value is None or len(self.value) == 0:
rpt.add('ERROR', 'MISSING_VALUE', "value array is empty or None")
if self.time is not None and self.value is not None and len(self.time) != len(self.value):
rpt.add('ERROR', 'LENGTH_MISMATCH', f"time ({len(self.time)}) and value ({len(self.value)}) length mismatch")
if self.gti_start is not None and self.gti_stop is not None:
if len(self.gti_start) != len(self.gti_stop):
rpt.add('ERROR', 'GTI_MISMATCH', "gti_start and gti_stop must have same length")
if len(self.gti_start) > 0 and not np.all(self.gti_start < self.gti_stop):
rpt.add('ERROR', 'GTI_ORDER', "gti_start must be < gti_stop")
if self.time is not None and len(self.time) > 1 and not np.all(np.diff(self.time) > 0):
rpt.add('WARN', 'TIME_NOT_SORTED', "time array is not strictly increasing")
if self.exposure is not None and self.exposure <= 0:
rpt.add('WARN', 'BAD_EXPOSURE', f"exposure must be > 0, got {self.exposure}")
try:
_, _, width = self._resolve_bin_geometry()
if self.time is not None and len(self.time) != len(width):
rpt.add('ERROR', 'BIN_LENGTH_MISMATCH', "time and bin geometry length mismatch")
if np.any(~np.isfinite(width)) or np.any(width <= 0):
rpt.add('ERROR', 'BAD_BIN_WIDTH', "bin widths must be finite and > 0")
except Exception as e:
rpt.add('WARN', 'BIN_GEOMETRY_UNRESOLVED', f"failed to resolve bin geometry: {e}")
if self.bin_exposure is not None and self.time is not None and len(self.bin_exposure) != len(self.time):
rpt.add('ERROR', 'BIN_EXPOSURE_MISMATCH', "bin_exposure length must match time length")
rpt.ok = len(rpt.errors()) == 0
return rpt
@property
def n(self) -> int:
return len(self.time) if self.time is not None else 0
@property
def absolute_time(self) -> np.ndarray:
tref = self.time_rel if self.time_rel is not None else self.time
if tref is None:
return np.asarray([], dtype=float)
return np.asarray(tref, dtype=float) + float(self.timezero)
@property
def gti(self) -> Optional[list[tuple[float, float]]]:
if self.gti_start is None or self.gti_stop is None:
return None
return [(float(s), float(e)) for s, e in zip(self.gti_start, self.gti_stop)]
[docs]
def get_time_object(self, index: Optional[int] = None) -> Optional[Any]:
if self.timezero_obj is None:
return None
from .time import TimeDelta
absolute_times = self.absolute_time
dt = TimeDelta(absolute_times[index], format='sec') if index is not None else TimeDelta(absolute_times, format='sec')
return self.timezero_obj + dt
@property
def bin_centers(self) -> Optional[np.ndarray]:
try:
left, right, _ = self._resolve_bin_geometry()
return 0.5 * (left + right)
except Exception:
return None
def _resolve_bin_geometry(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
if self.time is None:
return np.asarray([], float), np.asarray([], float), np.asarray([], float)
t = np.asarray(self.time, dtype=float)
n = t.size
if n == 0:
return np.asarray([], float), np.asarray([], float), np.asarray([], float)
if self.bin_lo is not None and self.bin_hi is not None:
lo = np.asarray(self.bin_lo, dtype=float)
hi = np.asarray(self.bin_hi, dtype=float)
if lo.shape == hi.shape == t.shape:
return lo, hi, hi - lo
if self.bin_width is not None:
bw = np.asarray(self.bin_width, dtype=float)
if bw.shape == t.shape:
return t - 0.5 * bw, t + 0.5 * bw, bw
if self.dt is not None:
dt_arr = np.asarray(self.dt, dtype=float)
if dt_arr.ndim == 0:
dt_val = float(dt_arr)
if np.isfinite(dt_val) and dt_val > 0:
bw = np.full_like(t, dt_val, dtype=float)
return t - 0.5 * bw, t + 0.5 * bw, bw
elif dt_arr.shape == t.shape:
bw = dt_arr
return t - 0.5 * bw, t + 0.5 * bw, bw
est = float(np.median(np.diff(t))) if n >= 2 else 1.0
bw = np.full_like(t, est, dtype=float)
return t - 0.5 * bw, t + 0.5 * bw, bw
@property
def mean_rate(self) -> Optional[float]:
if self.rate is None or len(self.rate) == 0:
return None
return float(np.mean(self.rate))
@property
def mean_counts(self) -> Optional[float]:
if self.counts is None or len(self.counts) == 0:
return None
return float(np.mean(self.counts))
@property
def total_counts(self) -> Optional[float]:
if self.counts is None:
return None
return float(np.sum(self.counts))
@property
def _legacy_value(self) -> Optional[np.ndarray]:
import warnings
warnings.warn("LightcurveData.value is deprecated; use counts or rate instead.", DeprecationWarning, stacklevel=2)
if self.counts is not None:
return self.counts
if self.rate is not None and self.dt is not None and np.any(self.dt > 0):
return self.rate * (self.dt if isinstance(self.dt, np.ndarray) else self.dt)
return None
@property
def _legacy_error(self) -> Optional[np.ndarray]:
import warnings
warnings.warn("LightcurveData.error is deprecated; use counts_err or rate_err instead.", DeprecationWarning, stacklevel=2)
if self.counts_err is not None:
return self.counts_err
if self.rate_err is not None and self.dt is not None:
return self.rate_err * self.dt
return None
@property
def _legacy_is_rate(self) -> bool:
import warnings
warnings.warn("LightcurveData.is_rate is deprecated; check counts/rate fields instead.", DeprecationWarning, stacklevel=2)
return self.rate is not None and (self.counts is None or self.counts.sum() == 0)
[docs]
def apply_gti(self, inplace: bool = False) -> 'LightcurveData':
if self.gti is None:
return self if inplace else self._copy()
gti_arr = np.array(self.gti)
bin_lo, _, _ = self._resolve_bin_geometry()
mask = _create_gti_mask(bin_lo, gti_arr, dt=self.dt if isinstance(self.dt, float) else None)
return self.apply_mask(np.asarray(mask, dtype=bool), inplace=inplace)
[docs]
def split_by_gti(self, min_points: int = 1) -> list['LightcurveData']:
if self.gti is None or len(self.gti) == 0:
return [self]
result = []
for start, stop in self.gti:
gti_arr = np.array([[start, stop]])
bin_lo, _, _ = self._resolve_bin_geometry()
mask = _create_gti_mask(bin_lo, gti_arr, dt=self.dt if isinstance(self.dt, float) else None)
mask = np.asarray(mask, dtype=bool)
if np.sum(mask) < min_points:
continue
lc_segment = self.apply_mask(mask, inplace=False)
lc_segment.gti_start = np.array([start])
lc_segment.gti_stop = np.array([stop])
result.append(lc_segment)
return result if result else [self]
[docs]
def apply_mask(self, mask: np.ndarray, inplace: bool = False, filtered_attrs: Optional[list[str]] = None) -> 'LightcurveData':
all_array_attrs = [
'bin_lo', 'bin_hi', 'bin_width', 'counts', 'rate', 'counts_err', 'rate_err',
'quality', 'fracexp', 'bin_exposure'
]
if filtered_attrs is None:
filtered_attrs = all_array_attrs
lc_new = self if inplace else self._copy()
if lc_new.time is not None:
lc_new.time = np.asanyarray(lc_new.time)[mask]
if lc_new.value is not None:
lc_new.value = np.asanyarray(lc_new.value)[mask]
if lc_new.error is not None:
lc_new.error = np.asanyarray(lc_new.error)[mask]
for attr in all_array_attrs:
val = getattr(lc_new, attr, None)
if val is None:
continue
if attr not in filtered_attrs:
setattr(lc_new, attr, None)
else:
try:
setattr(lc_new, attr, np.asanyarray(val)[mask])
except (IndexError, TypeError):
pass
try:
lo, hi, bw = lc_new._resolve_bin_geometry()
if lo.size > 0:
lc_new.tstart = float(lo[0])
lc_new.tseg = float(hi[-1] - lo[0])
lc_new.bin_lo = lo
lc_new.bin_hi = hi
lc_new.bin_width = bw
if np.allclose(bw, float(np.median(bw)), rtol=1e-8, atol=1e-12):
lc_new.binning = 'uniform'
else:
lc_new.binning = 'variable'
except Exception:
pass
return lc_new
[docs]
def join(self, other: 'LightcurveData', skip_checks: bool = False) -> 'LightcurveData':
import copy
import warnings
if self.mjdref is not None and other.mjdref is not None and self.mjdref != other.mjdref:
warnings.warn(
f"MJDref mismatch: self={self.mjdref}, other={other.mjdref}. Converting other to self's mjdref.",
UserWarning,
)
other = copy.deepcopy(other)
assert other.mjdref is not None and self.mjdref is not None
time_offset = (other.mjdref - self.mjdref) * 86400.0
o_lo, o_hi, o_bw = other._resolve_bin_geometry()
other.bin_lo = o_lo + time_offset
other.bin_hi = o_hi + time_offset
other.bin_width = o_bw
if other.gti_start is not None:
other.gti_start = other.gti_start + time_offset
if other.gti_stop is not None:
other.gti_stop = other.gti_stop + time_offset
other.mjdref = self.mjdref
if self.tstart is not None and other.tstart is not None and self.tstart < other.tstart:
first_lc, second_lc = self, other
else:
first_lc, second_lc = other, self
self_lo, _, _ = self._resolve_bin_geometry()
other_lo, _, _ = other._resolve_bin_geometry()
if len(np.intersect1d(self_lo, other_lo)) > 0:
warnings.warn("The two light curves have overlapping time ranges. In overlapping regions, counts will be summed.", UserWarning)
lo1, hi1, bw1 = first_lc._resolve_bin_geometry()
lo2, hi2, bw2 = second_lc._resolve_bin_geometry()
new_bin_lo = np.concatenate([lo1, lo2])
new_bin_hi = np.concatenate([hi1, hi2])
new_bin_width = np.concatenate([bw1, bw2])
new_counts = np.concatenate([first_lc.counts, second_lc.counts]) if (first_lc.counts is not None and second_lc.counts is not None) else None
new_rate = np.concatenate([first_lc.rate, second_lc.rate]) if (first_lc.rate is not None and second_lc.rate is not None) else None
new_counts_err = np.concatenate([first_lc.counts_err, second_lc.counts_err]) if (first_lc.counts_err is not None and second_lc.counts_err is not None) else None
new_rate_err = np.concatenate([first_lc.rate_err, second_lc.rate_err]) if (first_lc.rate_err is not None and second_lc.rate_err is not None) else None
if first_lc.gti is not None and second_lc.gti is not None:
new_gti = _join_gtis(np.array(first_lc.gti), np.array(second_lc.gti))
new_gti_start = new_gti[:, 0]
new_gti_stop = new_gti[:, 1]
else:
new_gti_start = None
new_gti_stop = None
lc_cls = type(self)
lc_new = lc_cls(
path=self.path,
bin_lo=new_bin_lo,
bin_hi=new_bin_hi,
bin_width=new_bin_width,
counts=new_counts,
rate=new_rate,
counts_err=new_counts_err,
rate_err=new_rate_err,
dt=self.dt,
tstart=float(new_bin_lo[0]),
gti_start=new_gti_start,
gti_stop=new_gti_stop,
mjdref=self.mjdref,
timesys=self.timesys,
exposure=self.exposure,
header=self.header,
meta=self.meta,
headers_dump=self.headers_dump,
columns=self.columns,
)
lc_new.binning = 'variable' if (new_bin_width.size > 1 and not np.allclose(new_bin_width, np.median(new_bin_width))) else 'uniform'
lc_new.tseg = float(new_bin_hi[-1] - new_bin_lo[0]) if len(new_bin_hi) > 0 else None
return lc_new
[docs]
def truncate(self, tmin: Optional[float] = None, tmax: Optional[float] = None) -> 'LightcurveData':
lo, hi, _ = self._resolve_bin_geometry()
tmin_val = tmin if tmin is not None else (lo[0] if len(lo) > 0 else 0.0)
tmax_val = tmax if tmax is not None else (hi[-1] if len(hi) > 0 else float(np.inf))
mask = (lo >= tmin_val) & (hi <= tmax_val)
lc_truncated = self.apply_mask(mask, inplace=False)
if lc_truncated.gti is not None:
gti_filtered = [
(max(float(s), tmin_val), min(float(e), tmax_val))
for s, e in lc_truncated.gti
if min(float(e), tmax_val) > max(float(s), tmin_val)
]
if gti_filtered:
lc_truncated.gti_start = np.array([s for s, _ in gti_filtered])
lc_truncated.gti_stop = np.array([e for _, e in gti_filtered])
else:
lc_truncated.gti_start = None
lc_truncated.gti_stop = None
return lc_truncated
[docs]
def sort(self, inplace: bool = False) -> 'LightcurveData':
lo, _, _ = self._resolve_bin_geometry()
sort_idx = np.argsort(lo)
mask = np.zeros(len(lo), dtype=bool)
mask[sort_idx] = True
return self.apply_mask(mask, inplace=inplace)
def _copy(self) -> 'LightcurveData':
import copy
return copy.copy(self)
[docs]
def slice(self, tmin: Optional[Union[float, Any]] = None, tmax: Optional[Union[float, Any]] = None) -> 'LightcurveData':
from .ops import slice_lightcurve
return cast(LightcurveData, slice_lightcurve(cast(Any, self), tmin=tmin, tmax=tmax))
[docs]
def rebin(
self,
binsize: float,
method: Literal['auto', 'sum', 'mean'] = 'auto',
*,
align_ref: Optional[float] = None,
empty_bin: Literal['zero', 'nan'] = 'zero',
) -> 'LightcurveData':
from .ops import rebin_lightcurve
return cast(LightcurveData, rebin_lightcurve(cast(Any, self), binsize=binsize, method=method, align_ref=align_ref, empty_bin=empty_bin))
def __sub__(
self,
other: 'LightcurveData',
*,
ratio: Optional[float] = None,
use_exposure_weighted_ratio: bool = True,
) -> 'LightcurveData':
from jinwu.core.datasets import netdata
result = netdata(cast(Any, self), cast(Any, other), ratio=ratio, use_exposure_weighted_ratio=use_exposure_weighted_ratio)
if getattr(result, "kind", None) != "lc":
raise TypeError(f"netdata must return LightcurveData-like object, got {type(result).__name__}")
return cast(LightcurveData, result)
def __add__(self, other: 'LightcurveData'):
from jinwu.core.datasets import LightcurveDataset
return LightcurveDataset(data=cast(Any, [self, other]))
[docs]
def plot(self, **kwargs):
from .plot import plot_lightcurve
return plot_lightcurve(self, **kwargs)
[docs]
class EventData(EventDataBase):
"""Concrete event data class with local behavior implementation."""
[docs]
def validate(self) -> ValidationReport: # type: ignore[override]
rpt = _validate_time_series_like(self)
colset = {c.upper() for c in self.columns}
if "TIME" not in colset:
rpt.add('ERROR', 'MISSING_COLUMN', "EVENTS missing TIME column")
rpt.ok = len(rpt.errors()) == 0
return rpt
@property
def n(self) -> int:
return len(self.time) if self.time is not None else 0
@property
def duration(self) -> Optional[float]:
if self.time.size:
return float(self.time.max() - self.time.min())
return None
@property
def absolute_time(self) -> np.ndarray:
tref = self.time_rel if self.time_rel.size > 0 else self.time
return np.asarray(tref, dtype=float) + float(self.timezero)
@property
def gti_exposure(self) -> Optional[float]:
if self.gti_start is None or self.gti_stop is None:
return None
return float(np.sum(self.gti_stop - self.gti_start))
[docs]
def get_time_object(self, index: Optional[int] = None) -> Optional[Any]:
if self.timezero_obj is None:
return None
from astropy.time import TimeDelta
absolute_times = self.absolute_time
dt = TimeDelta(absolute_times[index], format='sec') if index is not None else TimeDelta(absolute_times, format='sec')
return self.timezero_obj + dt
[docs]
def get_energy(self, rmf: Optional[RmfData] = None) -> Optional[np.ndarray]:
if getattr(self, 'energy', None) is not None:
return self.energy
if rmf is not None:
try:
dm = np.asarray(rmf.dense_matrix) if hasattr(rmf, 'dense_matrix') else np.asarray(rmf.rebuild_dense())
if getattr(rmf, 'energ_lo', None) is not None and getattr(rmf, 'energ_hi', None) is not None:
e_centers = 0.5 * (np.asarray(rmf.energ_lo) + np.asarray(rmf.energ_hi))
elif getattr(rmf, 'e_min', None) is not None and getattr(rmf, 'e_max', None) is not None:
e_centers = 0.5 * (np.asarray(rmf.e_min) + np.asarray(rmf.e_max))
else:
e_centers = None
if dm.ndim == 2:
if e_centers is not None and dm.shape[0] == e_centers.size and dm.shape[1] != e_centers.size:
dm_t = dm.T
else:
dm_t = dm
else:
dm_t = np.asarray(dm)
if getattr(self, 'channel', None) is not None:
ch_ev = np.asarray(self.channel).astype(int)
elif getattr(self, 'pi', None) is not None:
ch_ev = np.asarray(self.pi).astype(int)
else:
ch_ev = None
if (ch_ev is not None) and (e_centers is not None):
from ..ftools.rmf_mapping import map_channels_to_energy
mapped = map_channels_to_energy(dm_t, e_centers, ch_ev, method='expected')
self.energy = np.asarray(mapped, dtype=float)
return self.energy
except Exception:
pass
if getattr(self, 'ebounds', None) is not None:
ebounds_tuple = self.ebounds
if ebounds_tuple is not None and len(ebounds_tuple) == 3:
ch, e_lo, e_hi = ebounds_tuple
emid = 0.5 * (np.asarray(e_lo) + np.asarray(e_hi))
if getattr(self, 'channel', None) is not None:
ch_ev = np.asarray(self.channel)
elif getattr(self, 'pi', None) is not None:
ch_ev = np.asarray(self.pi)
else:
return None
cmap = {int(c): float(e) for c, e in zip(np.asarray(ch, int), emid)}
mapped = np.array([cmap.get(int(cc), np.nan) for cc in ch_ev], dtype=float)
self.energy = mapped
return self.energy
return None
[docs]
def plot(
self,
ax: Optional[Any] = None,
*,
bins: Union[int, Tuple[int, int]] = 300,
cmap: str = 'viridis',
title: Optional[str] = None,
invert_ra: bool = True,
show_grid: bool = True,
show_colorbar: bool = True,
**kwargs,
):
import matplotlib.pyplot as _plt
def _get_raw_col(name: Optional[str]) -> Optional[np.ndarray]:
if name is None or self.raw_columns is None:
return None
if name in self.raw_columns:
return np.asarray(self.raw_columns[name], float)
return None
def _find_name(cands: Tuple[str, ...]) -> Optional[str]:
if self.raw_columns is None:
return None
upper_to_orig = {str(k).upper(): str(k) for k in self.raw_columns.keys()}
for c in cands:
if c.upper() in upper_to_orig:
return upper_to_orig[c.upper()]
return None
def _to_ra_dec_from_xy(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
hdr = self.header or {}
try:
from astropy.wcs import WCS
w = WCS(hdr)
if getattr(w, 'has_celestial', False):
world = np.asarray(w.all_pix2world(np.column_stack([x, y]), 1), float)
ra_w = np.asarray(world[..., 0], float)
dec_w = np.asarray(world[..., 1], float)
if np.isfinite(ra_w).any() and np.isfinite(dec_w).any():
return ra_w, dec_w
except Exception:
pass
def _pick_float(keys: Tuple[str, ...], default: Optional[float] = None) -> Optional[float]:
for k in keys:
if k in hdr:
try:
return float(hdr[k])
except Exception:
continue
return default
crval1 = _pick_float(('TCRVL11', 'CRVAL1', 'TCRVL1'))
crval2 = _pick_float(('TCRVL12', 'CRVAL2', 'TCRVL2'))
crpix1 = _pick_float(('TCRPX11', 'CRPIX1', 'TCRPX1'), default=0.0)
crpix2 = _pick_float(('TCRPX12', 'CRPIX2', 'TCRPX2'), default=0.0)
cdelt1 = _pick_float(('TCDLT11', 'CDELT1', 'TCDLT1'))
cdelt2 = _pick_float(('TCDLT12', 'CDELT2', 'TCDLT2'))
if (crval1 is None) or (crval2 is None) or (cdelt1 is None) or (cdelt2 is None):
raise ValueError('Cannot build RA/DEC: missing RA/DEC columns and WCS keywords')
cos_dec0 = np.cos(np.deg2rad(crval2))
if np.abs(cos_dec0) < 1e-8:
cos_dec0 = 1.0
crpix1_f = float(0.0 if crpix1 is None else crpix1)
crpix2_f = float(0.0 if crpix2 is None else crpix2)
ra_l = crval1 + (x - crpix1_f) * float(cdelt1) / cos_dec0
dec_l = crval2 + (y - crpix2_f) * float(cdelt2)
return np.asarray(ra_l, float), np.asarray(dec_l, float)
ra_name = self.colmap.get('ra') if self.colmap is not None else None
dec_name = self.colmap.get('dec') if self.colmap is not None else None
ra = _get_raw_col(ra_name)
dec = _get_raw_col(dec_name)
if (ra is None) or (dec is None):
if ra is None:
ra = _get_raw_col(_find_name(('RA', 'RA_OBJ', 'RAX', 'RA_DEG')))
if dec is None:
dec = _get_raw_col(_find_name(('DEC', 'DEC_OBJ', 'DECX', 'DEC_DEG')))
if (ra is None) or (dec is None):
x = np.asarray(self.x, float) if self.x is not None else None
y = np.asarray(self.y, float) if self.y is not None else None
if (x is None) or (y is None):
key_x = self.colmap.get('x') if self.colmap is not None else None
key_y = self.colmap.get('y') if self.colmap is not None else None
x = _get_raw_col(key_x)
y = _get_raw_col(key_y)
if (x is None) or (y is None):
raise ValueError('No RA/DEC columns and no X/Y columns available for sky plotting')
ra, dec = _to_ra_dec_from_xy(np.asarray(x, float), np.asarray(y, float))
ra = np.asarray(ra, float)
dec = np.asarray(dec, float)
mask = np.isfinite(ra) & np.isfinite(dec)
if not np.any(mask):
raise ValueError('No finite RA/DEC values available for plotting')
ra = ra[mask]
dec = dec[mask]
ax = ax or _plt.gca()
assert ax is not None
h = ax.hist2d(ra, dec, bins=bins, cmap=cmap, **kwargs)
if show_colorbar:
_plt.colorbar(h[3], ax=ax, label='Counts')
ax.set_xlabel('Right ascension (deg)')
ax.set_ylabel('Declination (deg)')
if invert_ra:
ax.invert_xaxis()
if title is None:
fname = Path(str(getattr(self, 'path', ''))).name
title = f'{fname} sky image' if fname else 'Sky image'
ax.set_title(title)
if show_grid:
ax.grid(alpha=0.3, ls='--')
return ax
[docs]
def timescale(
self,
*,
background: Optional[Union['EventDataBase', str, Path]] = None,
alpha: Optional[float] = None,
) -> Any:
"""创建独立的 timescale 分析器。"""
return timescale(self, background=background, alpha=alpha)
[docs]
def slice(
self,
tmin: Optional[float] = None,
tmax: Optional[float] = None,
*,
pi_min: Optional[int] = None,
pi_max: Optional[int] = None,
ch_min: Optional[int] = None,
ch_max: Optional[int] = None,
) -> 'EventData':
from .ops import slice_events
return cast(EventData, slice_events(cast(Any, self), tmin=tmin, tmax=tmax, pi_min=pi_min, pi_max=pi_max, ch_min=ch_min, ch_max=ch_max))
[docs]
def rebin(self, binsize: float, *, tmin: Optional[float] = None, tmax: Optional[float] = None) -> 'LightcurveData':
from .ops import rebin_events_to_lightcurve
return cast(LightcurveData, rebin_events_to_lightcurve(cast(Any, self), binsize=binsize, tmin=tmin, tmax=tmax))
[docs]
def xselect(self) -> 'Any':
from .xselect import XSelectSession
if self._xselect_session is None:
self._xselect_session = cast(Any, XSelectSession(cast(Any, self)))
return self._xselect_session
[docs]
def filter_time(self, tmin: Optional[float] = None, tmax: Optional[float] = None) -> 'EventData':
session = self.xselect()
session.apply_time(tmin=tmin, tmax=tmax)
cur = getattr(session, 'current', None)
return cur if cur is not None else self
[docs]
def filter_region(self, region) -> 'EventData':
session = self.xselect()
session.apply_region(region)
cur = getattr(session, 'current', None)
return cur if cur is not None else self
[docs]
def filter_energy(self, pi_min: Optional[int] = None, pi_max: Optional[int] = None) -> 'EventData':
session = self.xselect()
session.apply_energy(pi_min=pi_min, pi_max=pi_max)
cur = getattr(session, 'current', None)
return cur if cur is not None else self
def _current_for_products(self) -> 'EventData':
sess = getattr(self, '_xselect_session', None)
if sess is not None:
cur = getattr(sess, 'current', None)
if cur is not None:
return cur
return self
[docs]
def save(self, outpath: str | Path, kind: str = 'evt', overwrite: bool = False, **kwargs) -> Path:
src = self._current_for_products()
if src is not self:
return src.save(outpath, kind=kind, overwrite=overwrite, **kwargs)
outp = Path(outpath)
if kind == 'evt':
from astropy.io import fits
from astropy.table import Table
if getattr(self, 'raw_columns', None):
try:
t = Table(self.raw_columns)
except Exception:
cols = {'TIME': self.time}
if getattr(self, 'x', None) is not None:
cols['X'] = getattr(self, 'x')
if getattr(self, 'y', None) is not None:
cols['Y'] = getattr(self, 'y')
if getattr(self, 'energy', None) is not None:
cols['ENERGY'] = getattr(self, 'energy')
t = Table(cols)
else:
cols = {'TIME': self.time}
if getattr(self, 'x', None) is not None:
cols['X'] = getattr(self, 'x')
if getattr(self, 'y', None) is not None:
cols['Y'] = getattr(self, 'y')
if getattr(self, 'energy', None) is not None:
cols['ENERGY'] = getattr(self, 'energy')
t = Table(cols)
t.write(outp, format='fits', overwrite=overwrite)
try:
with fits.open(outp, mode='update') as hdul:
if getattr(self, 'headers_dump', None) is not None and isinstance(self.headers_dump, FitsHeaderDump):
if getattr(self.headers_dump, 'primary', None) is not None:
for k, v in (self.headers_dump.primary or {}).items():
try:
prih = cast(Any, hdul[0])
prih.header[k] = v
except Exception:
continue
if getattr(self, 'header', None) is not None and len(hdul) > 1:
tbl = cast(Any, hdul[1])
for k, v in (self.header or {}).items():
if k in ('TFIELDS', 'TTYPE1', 'TFORM1', 'XTENSION', 'BITPIX', 'NAXIS'):
continue
try:
tbl.header[k] = v
except Exception:
continue
hdul.flush()
except Exception:
pass
return outp
if kind == 'lc':
lc = self.extract_curve(**kwargs)
from . import xselect
return xselect.write_curve(cast(Any, lc), outp, overwrite=overwrite)
if kind == 'pha':
pha = self.extract_spectrum(**kwargs)
from . import xselect
return xselect.write_pha(cast(Any, pha), outp, overwrite=overwrite)
if kind == 'img':
img, xedges, yedges = self.extract_image(**kwargs)
from . import xselect
return xselect.write_image(img, xedges, yedges, outp, overwrite=overwrite)
raise ValueError('unknown kind: ' + str(kind))
[docs]
def clear_region(self, *, use_original: bool = True) -> 'EventData':
sess = getattr(self, '_xselect_session', None)
if sess is not None:
sess.clear_region()
return self._current_for_products()
src = None
if use_original and getattr(self, '_original_events', None) is not None:
src = getattr(self, '_original_events')
elif getattr(self, 'path', None) is not None:
from .io import read_evt
src = read_evt(self.path)
else:
raise ValueError('No original events available to clear region')
tmin = None
tmax = None
time_arr = getattr(self, 'time', None)
if time_arr is not None and isinstance(time_arr, np.ndarray) and time_arr.size:
t = np.asarray(time_arr, dtype=float)
tmin = float(t.min())
tmax = float(t.max())
pi_min = None
pi_max = None
if getattr(self, 'pi', None) is not None:
arr = np.asarray(self.pi, dtype=int)
if arr.size:
pi_min = int(arr.min())
pi_max = int(arr.max())
elif getattr(self, 'channel', None) is not None:
arr = np.asarray(self.channel, dtype=int)
if arr.size:
pi_min = int(arr.min())
pi_max = int(arr.max())
from . import xselect as _xsel
return cast(EventData, _xsel.select_events(cast(Any, src), tmin=tmin, tmax=tmax, pi_min=pi_min, pi_max=pi_max))
[docs]
def clear_time(self, *, use_original: bool = True) -> 'EventData':
sess = getattr(self, '_xselect_session', None)
if sess is not None:
sess.clear_time()
return self._current_for_products()
src = None
if use_original and getattr(self, '_original_events', None) is not None:
src = getattr(self, '_original_events')
elif getattr(self, 'path', None) is not None:
from .io import read_evt
src = read_evt(self.path)
else:
raise ValueError('No original events available to clear time')
pi_min = None
pi_max = None
if getattr(self, 'pi', None) is not None:
arr = np.asarray(self.pi, dtype=int)
if arr.size:
pi_min = int(arr.min())
pi_max = int(arr.max())
elif getattr(self, 'channel', None) is not None:
arr = np.asarray(self.channel, dtype=int)
if arr.size:
pi_min = int(arr.min())
pi_max = int(arr.max())
from . import xselect as _xsel
return cast(EventData, _xsel.select_events(cast(Any, src), region=None, pi_min=pi_min, pi_max=pi_max))
[docs]
def clear_energy(self, *, use_original: bool = True) -> 'EventData':
sess = getattr(self, '_xselect_session', None)
if sess is not None:
sess.clear_energy()
return self._current_for_products()
src = None
if use_original and getattr(self, '_original_events', None) is not None:
src = getattr(self, '_original_events')
elif getattr(self, 'path', None) is not None:
from .io import read_evt
src = read_evt(self.path)
else:
raise ValueError('No original events available to clear energy')
tmin = None
tmax = None
time_arr = getattr(self, 'time', None)
if time_arr is not None and isinstance(time_arr, np.ndarray) and time_arr.size:
t = np.asarray(time_arr, dtype=float)
tmin = float(t.min())
tmax = float(t.max())
from . import xselect as _xsel
return cast(EventData, _xsel.select_events(cast(Any, src), tmin=tmin, tmax=tmax, region=None))
[docs]
def clear_all(self, *, use_original: bool = True) -> 'EventData':
sess = getattr(self, '_xselect_session', None)
if sess is not None:
sess.clear_all()
return self._current_for_products()
if use_original and getattr(self, '_original_events', None) is not None:
return getattr(self, '_original_events')
if getattr(self, 'path', None) is not None:
from .io import read_evt
return read_evt(self.path)
raise ValueError('No original events available to clear all')
EventLike = Union[EventDataBase, str, Path]
[docs]
class timescale:
"""事件时标分析器(独立于 EventData.plot)。
负责调用 `ops.txx` 计算时标,并通过 `plot.plot_event_txx` 可视化。
"""
def __init__(
self,
source: EventLike,
*,
background: Optional[EventLike] = None,
alpha: Optional[float] = None,
timezero: Optional[Any] = None,
) -> None:
self.source = source
self.background = background
self.alpha = alpha
self.timezero = timezero
self.result: Optional[Dict[str, Any]] = None
self.summary_text: Optional[str] = None
_COMMON_RESULT_KEYS = {
't100',
't100_err',
't100_value',
't100_tstart',
't100_tstop',
't90',
't90_err',
't90_value',
't90_tstart',
't90_tstop',
't50',
't50_err',
't50_value',
't50_tstart',
't50_tstop',
}
_VALUE_ERR_PACK_KEYS = {'t50', 't90', 't100'}
@staticmethod
def _as_err_low_high(err_any: Any) -> Dict[str, float]:
err = np.asarray(err_any, dtype=float).reshape(-1)
if err.size >= 2 and np.all(np.isfinite(err[:2])):
return {
'low': float(abs(err[0])),
'high': float(abs(err[1])),
}
if err.size == 1 and np.isfinite(err[0]):
e = float(abs(err[0]))
return {
'low': e,
'high': e,
}
return {
'low': float('nan'),
'high': float('nan'),
}
def __getattr__(self, name: str) -> Any:
# 允许直接访问 result 字典中的字段。
# 对 t50/t90/t100,默认返回 value + err_low/err_high 的组合输出。
res = self.result
if isinstance(res, dict):
if name.endswith('_value'):
base = name[:-6]
if base in self._VALUE_ERR_PACK_KEYS and (base in res):
return res[base]
if name in self._VALUE_ERR_PACK_KEYS and (name in res):
err_pair = self._as_err_low_high(
res.get(f'{name}_err', np.asarray([], dtype=float))
)
return {
'value': res[name],
'err_low': err_pair['low'],
'err_high': err_pair['high'],
'err': {
'low': err_pair['low'],
'high': err_pair['high'],
},
}
if name in res:
return res[name]
if name in self._COMMON_RESULT_KEYS:
raise AttributeError(
f"timescale.{name} 需要先运行 compute() 或 plot() 后再访问"
)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
def __dir__(self) -> list[str]:
# 在自动补全中展示当前 result 可直接访问的字段。
names = set(super().__dir__())
res = self.result
if isinstance(res, dict):
names.update(
k for k in res.keys() if isinstance(k, str) and k.isidentifier()
)
names.update(
f'{k}_value'
for k in self._VALUE_ERR_PACK_KEYS
if k in res
)
return sorted(names)
@staticmethod
def _safe_float(val: Any, default: float = np.nan) -> float:
try:
out = float(val)
if np.isfinite(out):
return out
except Exception:
pass
return float(default)
def _resolve_timezero_seconds(self) -> float:
tz = self.timezero
if tz is None:
return np.nan
tz_num = self._safe_float(tz, np.nan)
if np.isfinite(tz_num):
return tz_num
# 尝试将 astropy.time.Time 转换到事件时间标尺(秒)。
if hasattr(tz, 'mjd'):
src = self.source
src_tz = self._safe_float(getattr(src, 'timezero', np.nan), np.nan)
src_tz_obj = getattr(src, 'timezero_obj', None)
# 当 source 是路径时,读取事件头以获得 timezero/timezero_obj 作为锚点。
if ((not np.isfinite(src_tz)) or (src_tz_obj is None)) and isinstance(src, (str, Path)):
try:
from .io import read_evt
src_evt = read_evt(src)
src_tz = self._safe_float(getattr(src_evt, 'timezero', np.nan), np.nan)
src_tz_obj = getattr(src_evt, 'timezero_obj', None)
except Exception:
pass
if np.isfinite(src_tz) and (src_tz_obj is not None):
# 优先走 Time 差值到秒,避免不同 Time 格式直接读 mjd 带来的边界误差。
try:
dt = tz - src_tz_obj
if hasattr(dt, 'to_value'):
dt_sec = self._safe_float(dt.to_value('sec'), np.nan)
if np.isfinite(dt_sec):
return src_tz + dt_sec
if hasattr(dt, 'sec'):
dt_sec = self._safe_float(dt.sec, np.nan)
if np.isfinite(dt_sec):
return src_tz + dt_sec
except Exception:
pass
# 兜底:mjd 差值。
mjd_ref = self._safe_float(getattr(src_tz_obj, 'mjd', np.nan), np.nan)
mjd_val = self._safe_float(getattr(tz, 'mjd', np.nan), np.nan)
if np.isfinite(mjd_ref) and np.isfinite(mjd_val):
return src_tz + (mjd_val - mjd_ref) * 86400.0
return np.nan
@staticmethod
def _as_asym_text(value: float, err_any: Any) -> str:
err = np.asarray(err_any, dtype=float).reshape(-1)
if err.size >= 2 and np.all(np.isfinite(err[:2])):
lo = float(abs(err[0]))
hi = float(abs(err[1]))
return f'{value:.1f} (-{lo:.1f}/+{hi:.1f})'
if err.size == 1 and np.isfinite(err[0]):
e = float(abs(err[0]))
return f'{value:.1f} (-{e:.1f}/+{e:.1f})'
return f'{value:.1f}'
def _format_txx_line(self, res: Dict[str, Any], key: str, *, tz_ref: float) -> str:
val = self._safe_float(res.get(key, np.nan), np.nan)
if not np.isfinite(val):
return f'{key.upper()}: N/A'
main = self._as_asym_text(val, res.get(f'{key}_err', np.asarray([], dtype=float)))
line = f'{key.upper()}: {main} s'
t_start = self._safe_float(res.get(f'{key}_tstart', np.nan), np.nan)
t_stop = self._safe_float(res.get(f'{key}_tstop', np.nan), np.nan)
if np.isfinite(t_start) and np.isfinite(t_stop):
if np.isfinite(tz_ref):
rs = t_start - tz_ref
re = t_stop - tz_ref
line += f' [{rs:.1f}, {re:.1f}] s rel. timezero'
else:
line += f' [{t_start:.1f}, {t_stop:.1f}] s'
return line
def _build_summary_text(self, res: Dict[str, Any]) -> str:
tz_ref = self._resolve_timezero_seconds()
lines = ['Timescale summary']
tz_obj = self.timezero
tz_line: Optional[str] = None
if tz_obj is not None:
try:
tz_line = f'timezero = {tz_obj.utc.isot}'
except Exception as e:
raise ValueError(f'无法解析 timezero 对象为utc格式:{e}')
if tz_line is None and np.isfinite(tz_ref):
tz_line = f'timezero = {tz_ref:.6f} s (output uses relative boundaries)'
if tz_line is not None:
lines.append(tz_line)
lines.append(self._format_txx_line(res, 't100', tz_ref=tz_ref))
lines.append(self._format_txx_line(res, 't90', tz_ref=tz_ref))
lines.append(self._format_txx_line(res, 't50', tz_ref=tz_ref))
return '\n'.join(lines)
[docs]
def summary(self, result: Optional[Dict[str, Any]] = None) -> str:
res = self.result if result is None else result
if res is None:
raise ValueError('timescale.summary 需要先运行 compute,或传入 result')
text = self._build_summary_text(cast(Dict[str, Any], res))
self.summary_text = text
return text
[docs]
def compute(
self,
*,
background: Optional[EventLike] = None,
alpha: Optional[float] = None,
**txx_kwargs: Any,
) -> Dict[str, Any]:
"""调用 `ops.txx` 计算时标结果。"""
from .ops import txx as _txx
bkg = self.background if background is None else background
if bkg is None:
raise ValueError('timescale.compute 需要提供 background(EventData 或路径)')
alpha_eff = self.alpha if alpha is None else alpha
run_kwargs = dict(txx_kwargs)
if alpha_eff is not None:
run_kwargs['alpha'] = float(alpha_eff)
res = cast(Dict[str, Any], _txx(self.source, background=bkg, **run_kwargs))
# 记录用户用于计算的固定分箱宽度,便于绘图时恢复原始 light curve。
if 'evt_binsize' in run_kwargs:
try:
bs = float(run_kwargs['evt_binsize'])
if np.isfinite(bs) and bs > 0.0:
res.setdefault('evt_binsize', bs)
except Exception:
pass
self.background = bkg
if alpha_eff is not None:
self.alpha = float(alpha_eff)
self.result = res
self.summary_text = self._build_summary_text(res)
print(self.summary_text)
return res
[docs]
def plot(
self,
result: Optional[Dict[str, Any]] = None,
*,
srcname: Optional[str] = None,
title: Optional[str] = None,
figsize: Tuple[float, float] = (10.5, 8.2),
out: Optional[Union[str, Path]] = None,
binsize: Optional[float] = None,
background: Optional[EventLike] = None,
alpha: Optional[float] = None,
**txx_kwargs: Any,
):
"""绘制时标诊断图;若无 result 则内部先 compute。
参数
- binsize: 直方图分箱宽度(秒)。等价于传入 `evt_binsize` 给 `compute/ops.txx`。
这样可以在不显式调用 `compute` 的情况下直接 `plot(..., binsize=...)`。
"""
from .plot import plot_event_txx
run_kwargs = dict(txx_kwargs)
if 'timezero' in run_kwargs:
raise ValueError('请在 timescale 初始化时设置 timezero,不要在 plot 中传入')
if binsize is not None:
if 'evt_binsize' in run_kwargs:
raise ValueError('请只设置一个分箱参数:binsize 或 evt_binsize')
run_kwargs['evt_binsize'] = float(binsize)
if result is None:
if self.result is not None and background is None and alpha is None and not run_kwargs:
result_eff = self.result
else:
result_eff = self.compute(background=background, alpha=alpha, **run_kwargs)
else:
result_eff = cast(Dict[str, Any], result)
self.result = result_eff
bkg_eff = self.background if background is None else background
alpha_eff = self.alpha if alpha is None else alpha
fig, axes = plot_event_txx(
self.source,
result_eff,
background=bkg_eff,
alpha=alpha_eff,
srcname=srcname,
title=title,
figsize=figsize,
out=out,
timezero=self.timezero,
)
return fig, axes, result_eff
__all__ = [
'ArfData',
'RmfData',
'PhaData',
'LightcurveData',
'EventData',
'timescale',
]