'''
Date: 2025-05-30 17:43:59
LastEditors: Xinxiang Sun sunxx@nao.cas.cn
LastEditTime: 2025-11-07 14:35:02
LastEditTime: 2025-09-25 20:34:19
FilePath: /research/jinwu/src/jinwu/core/utils.py
'''
import math
import numpy as np
from typing import Union
import os
import gzip
import shutil
from pathlib import Path
def _require_xspec():
try:
import xspec # noqa: F401
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"xspec is required for this functionality. Please install HEASOFT/pyxspec and ensure 'xspec' is importable."
) from exc
def generate_download_url(isot_time):
    """Build the HEASARC URL of the Fermi/GBM daily-data directory for a date.

    Parameters
    ----------
    isot_time : str, datetime.date/datetime.datetime, or Time-like
        The date of interest. Accepted forms:

        - an ISOT string such as ``"2024-01-01T12:00:00"``;
        - a ``datetime`` / ``date`` instance;
        - any object exposing a ``datetime`` attribute that yields a
          ``datetime`` (e.g. an ``astropy.time.Time`` scalar), which is what
          the previous implementation required.

    Returns
    -------
    str
        URL of the ``.../gbm/daily/YYYY/MM/DD/current`` directory. The
        poshist file name is not appended; callers receive the directory
        URL only.

    Raises
    ------
    ValueError
        If a string argument is not a valid ISO-8601 date/time.
    TypeError
        If the argument cannot be interpreted as a calendar date.
    """
    from datetime import datetime as _datetime

    if isinstance(isot_time, str):
        # The documented input type; previously this crashed on `.strftime`.
        when = _datetime.fromisoformat(isot_time.strip())
    elif hasattr(isot_time, "datetime"):
        # Time-like objects (astropy Time) expose the calendar date here.
        when = isot_time.datetime
    else:
        when = isot_time

    try:
        year, month, day = when.year, when.month, when.day
    except AttributeError as exc:
        raise TypeError(
            "isot_time must be an ISOT string, a datetime, or an object with a "
            f"'datetime' attribute, got {type(isot_time).__name__}"
        ) from exc

    # Example: https://heasarc.gsfc.nasa.gov/FTP/fermi/data/gbm/daily/2025/01/01/current
    return (
        "https://heasarc.gsfc.nasa.gov/FTP/fermi/data/gbm/daily/"
        f"{year}/{month:02d}/{day:02d}/current"
    )
def extract_all_gz_recursive(root_path: Union[str, os.PathLike, Path],
                             remove_gz: bool = True,
                             verbose: bool = True) -> int:
    """Recursively decompress every ``.gz`` file found under a directory.

    Each archive ``name.ext.gz`` is written next to itself as ``name.ext``.
    A failure on one archive is reported on stdout and does not stop the
    run; the truncated output of a failed archive is removed so that no
    partial file is left behind.

    Parameters
    ----------
    root_path : str, os.PathLike or pathlib.Path
        Root directory to scan.
    remove_gz : bool
        Delete the original ``.gz`` file after a successful extraction
        (default True).
    verbose : bool
        Print a progress line per file and a final summary (default True).
        Error lines are printed regardless of this flag.

    Returns
    -------
    int
        Number of archives successfully decompressed.

    Raises
    ------
    FileNotFoundError
        If ``root_path`` does not exist.
    NotADirectoryError
        If ``root_path`` is not a directory.

    Examples
    --------
    >>> extract_all_gz_recursive('/path/to/data')
    >>> extract_all_gz_recursive(Path.home() / 'data')
    >>> extract_all_gz_recursive('C:/data', remove_gz=False)
    """
    root = Path(root_path)
    if not root.exists():
        raise FileNotFoundError(f"路径不存在: {root}")
    if not root.is_dir():
        raise NotADirectoryError(f"不是目录: {root}")

    # Snapshot the matches before touching the tree: the loop below creates
    # and deletes files under `root`, which must not feed back into the scan
    # (so a `x.gz.gz` archive is unwrapped exactly one level per call).
    # Directories that merely end in ".gz" are not archives and are skipped.
    archives = [path for path in root.rglob('*.gz') if path.is_file()]

    count = 0
    for gz_file in archives:
        # Output path is the archive path without its trailing ".gz".
        output_file = gz_file.with_suffix('')
        output_opened = False
        try:
            if verbose:
                print(f"解压: {gz_file} -> {output_file}")
            try:
                with gzip.open(gz_file, 'rb') as f_in:
                    with open(output_file, 'wb') as f_out:
                        output_opened = True
                        shutil.copyfileobj(f_in, f_out)
            except Exception:
                # Only clean up a file this call actually opened for writing;
                # its content is truncated/corrupt at this point.
                if output_opened:
                    try:
                        output_file.unlink()
                    except OSError:
                        pass
                raise
            if remove_gz:
                gz_file.unlink()
                if verbose:
                    print(f" 已删除: {gz_file}")
            count += 1
        except Exception as e:
            # Best effort: report and move on to the next archive.
            print(f"❌ 错误处理 {gz_file}: {e}")
            continue
    if verbose:
        print(f"\n✅ 总共解压 {count} 个文件")
    return count
# 便捷别名
def gunzip(root_path, remove_gz=True, verbose=True):
    """Convenience alias for ``extract_all_gz_recursive``.

    Takes the same arguments and returns the same value (the number of
    archives that were decompressed).
    """
    extracted = extract_all_gz_recursive(
        root_path,
        remove_gz=remove_gz,
        verbose=verbose,
    )
    return extracted
# Legacy LF/redshift extrapolator moved out of core utilities.
def get_asym_err(param):
    """Return the asymmetric errors of an XSPEC parameter.

    Parameters
    ----------
    param : object
        Parameter-like object with ``values`` (first element is the current
        value) and ``error`` (first two elements are the lower and upper
        bounds). Presumably a PyXspec ``Parameter`` after ``Fit.error`` has
        been run -- confirm against the caller.

    Returns
    -------
    tuple
        ``(err_low, err_high)``: absolute distances from the parameter value
        to the lower and upper bound respectively.

    Raises
    ------
    RuntimeError
        If the bounds or the value cannot be read; the underlying exception
        is attached as ``__cause__``.
    """
    try:
        offsets = np.array(param.error[:2]) - param.values[0]
        return abs(offsets[0]), abs(offsets[1])
    except Exception as e:
        # Look the name up defensively: the object that just failed may not
        # have a usable `name` either, and that must not mask the real error.
        label = getattr(param, "name", "<unknown>")
        raise RuntimeError(
            f"Error getting asymmetric error for parameter {label}: {e}"
        ) from e
def flux_err_from_log10(lgflux, log_err_low, log_err_high):
    """Convert errors on log10(flux) into errors on the linear flux.

    Parameters
    ----------
    lgflux : float
        Base-10 logarithm of the flux.
    log_err_low, log_err_high : float
        Lower / upper error on ``lgflux`` (in dex).

    Returns
    -------
    tuple
        ``(err_low, err_high)`` in linear flux units, or ``(None, None)``
        when any input is ``None`` or the conversion fails for any reason.
    """
    if any(item is None for item in (lgflux, log_err_low, log_err_high)):
        return None, None
    try:
        centre = 10.0 ** lgflux
        lower_edge = 10.0 ** (lgflux - float(log_err_low))
        upper_edge = 10.0 ** (lgflux + float(log_err_high))
        return centre - lower_edge, upper_edge - centre
    except Exception:
        return None, None
def generate_xspec_result(model, spectrum) -> dict:
    """Summarise a fitted XSPEC model/spectrum pair as a result dictionary.

    The model must contain a ``cflux`` component: its ``Emin``/``Emax``
    define the band passed to ``AllModels.calcFlux`` and its ``lg10Flux`` is
    used for the flux-per-count-rate conversion factor.

    Parameters
    ----------
    model : object
        XSPEC model object (``expression``, ``componentNames`` and one
        attribute per component are read).
    spectrum : object
        XSPEC spectrum object (``flux``, ``rate`` and, if present,
        ``exposure`` are read).

    Returns
    -------
    dict
        Keys:

        - ``model``: the model expression string;
        - ``parameters``: ``{"comp.param": {"value", "frozen"[, "error_lo",
          "error_hi"]}}`` (errors only for free parameters);
        - ``flux_abs``: ``{"erg_cm2_s", "photons_cm2_s"}`` taken from
          ``spectrum.flux[0]`` and ``spectrum.flux[3]``;
        - ``rate``: ``{"value", "error"}``;
        - ``conversion``: ``{"exposure_s", "erg_per_count", "counts"}``;
        - ``statistics``: ``{"method", "value", "dof", "reduced",
          "null_hypothesis_probability"}``; ``reduced`` is NaN when the fit
          has zero degrees of freedom;
        - ``text``: a human-readable multi-line report of the above.

    Raises
    ------
    ModuleNotFoundError
        If PyXspec is not importable.
    RuntimeError
        If a component/parameter or the count rate cannot be read.
    ValueError
        If the model has no ``cflux`` component.
    """
    _require_xspec()
    import xspec

    lines = [f"Model: {model.expression}"]
    parameters = {}
    result = {'model': model.expression, 'parameters': parameters}

    # --- model parameters -------------------------------------------------
    for comp_name in model.componentNames:
        try:
            comp = getattr(model, comp_name)
            for param_name in comp.parameterNames:
                param_key = f"{comp_name}.{param_name}"
                if param_key in parameters:
                    continue  # already reported
                param = getattr(comp, param_name)
                param_val = param.values[0]
                param_dict = {'value': param_val, 'frozen': param.frozen}
                if param.frozen:
                    lines.append(f"{comp_name}.{param_name}: {param_val:.4f} (fixed)")
                else:
                    err_lo, err_hi = get_asym_err(param)
                    param_dict['error_lo'] = err_lo
                    param_dict['error_hi'] = err_hi
                    lines.append(f"{comp_name}.{param_name}: {param_val:.4f} (-{err_lo:.4f}, +{err_hi:.4f})(1sigma error)")
                parameters[param_key] = param_dict
        except Exception as e:
            raise RuntimeError(f"Error processing component {comp_name}: {e}") from e

    # --- flux -------------------------------------------------------------
    # Everything below reads the energy band and lg10Flux from `cflux`.
    # Without it the old code passed "None None" to calcFlux and then failed
    # while formatting the band, so fail here with an explicit message.
    if not hasattr(model, 'cflux'):
        raise ValueError(
            "generate_xspec_result requires a model with a 'cflux' component "
            f"(got model expression: {model.expression})"
        )
    cflux = model.cflux
    emin = cflux.Emin.values[0]
    emax = cflux.Emax.values[0]
    xspec.AllModels.calcFlux(f"{emin} {emax}")
    flux_erg = float(spectrum.flux[0])
    flux_photons = float(spectrum.flux[3])
    result['flux_abs'] = {
        'erg_cm2_s': flux_erg,
        'photons_cm2_s': flux_photons,
    }
    lines.append(f"Absorbed Flux ({emin:.1f}-{emax:.1f} keV): {flux_erg:.4e} erg/cm²/s")
    lines.append(f"Absorbed Photon Flux ({emin:.1f}-{emax:.1f} keV): {flux_photons:.4e} photons/cm²/s")

    # --- count rate -------------------------------------------------------
    try:
        rate = float(spectrum.rate[0])
        rate_err = float(spectrum.rate[1]) if len(spectrum.rate) > 1 else None
    except Exception as e:
        raise RuntimeError(f"Error extracting rate: {e}") from e
    result['rate'] = {'value': rate, 'error': rate_err}
    if rate_err is not None:
        lines.append(f"Rate: {rate:.4f} ± {rate_err:.4f} cts/s")
    else:
        lines.append(f"Rate: {rate:.4f} cts/s")

    # --- conversion factor and total counts --------------------------------
    exposure = getattr(spectrum, 'exposure', None)
    if rate > 0 and flux_erg > 0:
        conv_factor = 10**cflux.lg10Flux.values[0] / rate
    else:
        conv_factor = None
    photon_counts = rate * exposure if exposure is not None else None
    result['conversion'] = {
        'exposure_s': exposure,
        'erg_per_count': conv_factor,
        'counts': photon_counts,
    }
    if exposure is not None:
        lines.append(f"Exposure: {exposure:.1f} s")
    if conv_factor is not None:
        lines.append(f"Conversion factor: {conv_factor:.4e} erg/cm²/s per cts/s")
    if photon_counts is not None:
        lines.append(f"Total counts: {photon_counts:.2f} counts")

    # --- fit statistics ---------------------------------------------------
    statistic = xspec.Fit.statistic
    dof = xspec.Fit.dof
    stat_method = xspec.Fit.statMethod
    null_prob = xspec.Fit.nullhyp
    # A saturated fit has dof == 0; report NaN instead of dividing by zero.
    statdof = statistic / dof if dof else float('nan')
    lines.append(f"Stat/dof: {stat_method}={statistic:.2f}/{dof}={statdof:.2f}")
    lines.append(f"Null hypothesis probability: {null_prob:.4f}")
    result['statistics'] = {
        'method': stat_method,
        'value': statistic,
        'dof': dof,
        'reduced': statdof,
        'null_hypothesis_probability': null_prob,
    }

    result['text'] = "\n".join(lines)
    return result
def _parse_nhtot_response(html, coord_str=""):
"""Parse the ASCII table from the nhtot HTML response.
Pure parsing function — no network I/O. Separated for testability.
Parameters
----------
html : str
Raw HTML response body from donhtot.php.
coord_str : str
Coordinate string for error messages only.
Returns
-------
dict
Always contains ``ok`` (bool). On success (ok=True), also contains
ra, dec, ebv_mean/weighted, nhi_mean/weighted, nh2_mean/weighted,
nhtot_mean/weighted. On failure (ok=False), contains ``error`` (str)
and all value fields set to None.
"""
import re
_NONE = {
'ra': None, 'dec': None,
'ebv_mean': None, 'ebv_weighted': None,
'nhi_mean': None, 'nhi_weighted': None,
'nh2_mean': None, 'nh2_weighted': None,
'nhtot_mean': None, 'nhtot_weighted': None,
}
def _num(s):
try:
return float(s)
except (ValueError, TypeError):
return None
lines = html.split('\n')
# Locate data row between the first and second +===+ separator rows.
# The data row is the first line after the opening === that contains
# both a pipe and a celestial coordinate prefix (J or B).
data_line = None
after_header = False
for line in lines:
if '+===' in line:
if not after_header:
after_header = True # opening === row, data follows
else:
break # closing === row, stop
elif after_header and '|' in line:
stripped = line.strip()
# Confirm this looks like a data row: starts with J/B prefix
# (possibly after an optional leading | from old-style tables)
if stripped and re.match(r'^\|?\s*[JB]\s', stripped):
data_line = line
break
if data_line is None:
return {**_NONE, 'ok': False,
'error': f'No data row found for {coord_str}'}
# Split on | and drop empty fields (fix: leading/trailing | robustness)
fields = [f.strip() for f in data_line.split('|') if f.strip()]
if len(fields) < 9:
return {**_NONE, 'ok': False,
'error': f'Expected >=9 fields, got {len(fields)} for {coord_str}'}
# Parse position from fields[0] (fix: strict prefix removal)
pos = fields[0]
ra_str, dec_str = None, None
if pos:
pos_clean = pos
for prefix in ('J ', 'B '):
if pos_clean.startswith(prefix):
pos_clean = pos_clean[len(prefix):]
break
parts = pos_clean.split(',')
if len(parts) == 2:
ra_str, dec_str = parts[0].strip(), parts[1].strip()
return {
'ok': True,
'ra': ra_str,
'dec': dec_str,
'ebv_mean': _num(fields[1]),
'ebv_weighted': _num(fields[2]),
'nhi_mean': _num(fields[3]),
'nhi_weighted': _num(fields[4]),
'nh2_mean': _num(fields[5]),
'nh2_weighted': _num(fields[6]),
'nhtot_mean': _num(fields[7]),
'nhtot_weighted': _num(fields[8]),
}
def nhtot(ra, dec, equinox=2000):
    """
    Query the Swift UKSSDC nhtot service for the Galactic hydrogen column
    density, following Willingale et al. (2013, MNRAS, 431, 394).

    Parameters
    ----------
    ra : float or str
        Right Ascension, decimal degrees (e.g. 159.386) or sexagesimal
        (e.g. "10:37:32.6").
    dec : float or str
        Declination, decimal degrees (e.g. 56.171) or sexagesimal
        (e.g. "+56:10:15.6").
    equinox : int
        2000 for J2000, 1950 for B1950.

    Returns
    -------
    dict
        Keys: ok (bool), ra (str), dec (str), ebv_mean, ebv_weighted,
        nhi_mean, nhi_weighted, nh2_mean, nh2_weighted, nhtot_mean,
        nhtot_weighted (floats). All NH in atoms cm⁻², E(B-V) in mag.
        If the request or the parsing fails, ``ok`` is False, ``error``
        holds a message, and every value field is None. A parse failure is
        additionally reported on stdout.

    Examples
    --------
    >>> result = nhtot(159.386, 56.171)
    >>> print(result['nhtot_weighted'])
    5.26e+19
    >>> result = nhtot("10:30:00", "+50:00:00")
    """
    import urllib.parse
    import urllib.request

    coord_str = f"{ra} {dec}"
    form_fields = {
        'Coords': coord_str,
        'equinox': str(equinox),
        'ascii': '1',
        'jsOn': '1',
        'obname': '',
        'MAX_FILE_SIZE': '1000000',
    }
    payload = urllib.parse.urlencode(form_fields).encode('ascii')
    endpoint = "https://www.swift.ac.uk/analysis/nhtot/donhtot.php"

    try:
        # Supplying `data` makes this a POST request.
        request = urllib.request.Request(endpoint, data=payload)
        with urllib.request.urlopen(request, timeout=30) as resp:
            html = resp.read().decode('utf-8', errors='replace')
    except Exception as e:
        # Any network-level failure is reported through the result dict.
        value_keys = (
            'ra', 'dec',
            'ebv_mean', 'ebv_weighted',
            'nhi_mean', 'nhi_weighted',
            'nh2_mean', 'nh2_weighted',
            'nhtot_mean', 'nhtot_weighted',
        )
        return {'ok': False, 'error': str(e), **dict.fromkeys(value_keys)}

    result = _parse_nhtot_response(html, coord_str)
    if not result.get('ok'):
        print(f"nhtot: parse failed for {coord_str}: {result.get('error', 'unknown')}")
    return result
class HydroDynamics:
    """Helper class for classical / relativistic hydrodynamics."""

    @classmethod
    def show_shock_jump_conditions(cls):
        """
        Display the shock jump conditions (Rankine-Hugoniot conditions) as
        rendered LaTeX through ``IPython.display``.
        """
        from IPython.display import display, Math

        heading = r"\text{激波跳变条件(Rankine-Hugoniot conditions):}"
        # Density / velocity ratio across the shock.
        density_ratio = r"\frac{\rho_2}{\rho_1} = \frac{v_1}{v_2} = \frac{(\hat{\gamma}+1)M_1^2}{(\hat{\gamma}-1)M_1^2+2}"
        # Pressure ratio across the shock.
        pressure_ratio = r"\frac{p_2}{p_1} = \frac{2\hat{\gamma} M_1^2 - \hat{\gamma} + 1}{\hat{\gamma} + 1}"
        # Temperature ratio across the shock.
        temperature_ratio = r"\frac{T_2}{T_1} = \frac{p_2 \rho_1}{p_1 \rho_2} = \frac{(2\hat{\gamma} M_1^2 - \hat{\gamma} + 1)[(\hat{\gamma}-1)M_1^2+2]}{(\hat{\gamma}+1)^2 M_1^2}"

        for latex in (heading, density_ratio, pressure_ratio, temperature_ratio):
            display(Math(latex))
class SFH:
    """Star formation history (SFH) of a galaxy.

    Placeholder for handling and analysing the formation and evolution
    history of galaxies; it currently holds no state and has no behaviour.
    """

    def __init__(self) -> None:
        """Create an empty SFH object (nothing is initialised yet)."""
# NOTE for future AI / maintainers:
# RedshiftExtrapolator was moved to jinwu.lf.legacy_redshift.
# Do NOT re-add a top-level import of anything under jinwu.lf here —
# it creates a circular import (core.utils -> lf -> detectability -> core.utils).
# Users who need the legacy class should import it directly:
# from jinwu.lf.legacy_redshift import RedshiftExtrapolator
# ======================================================================
# Li & Ma SNR 和触发判断工具
# ======================================================================
# [docs]  (Sphinx "view source" link residue, commented out: as a bare expression it raises NameError)
def li_ma_snr(n_on: float, n_off: float, alpha: float, *, signed: bool = True) -> float:
    """Li & Ma significance (Li & Ma 1983, Eq. 17).

    Parameters
    ----------
    n_on : float
        Total counts in the ON region.
    n_off : float
        Total counts in the OFF (reference background) region.
    alpha : float
        ON/OFF scaling, ``alpha = (A_on/A_off) * (t_on/t_off)``.
    signed : bool
        When ``True`` (default) the result carries the sign of the excess
        ``n_on - alpha * n_off``: positive for a source excess, negative for
        a deficit. With ``signed=False`` only the magnitude is returned.

    Returns
    -------
    float
        The significance. ``0.0`` is returned for degenerate input: both
        counts non-positive, or ``alpha <= 0``. Negative counts are clipped
        to zero before the formula is evaluated.
    """
    nothing_observed = n_on <= 0 and n_off <= 0
    if nothing_observed or alpha <= 0:
        return 0.0

    on = float(max(n_on, 0.0))
    off = float(max(n_off, 0.0))

    # Magnitude S = sqrt(2 ln L). When one of the counts is zero the
    # corresponding x*ln(x) term vanishes, so those limits are written out
    # explicitly instead of evaluating log(0).
    if on == 0.0:
        magnitude = float(np.sqrt(2.0 * off * np.log(1.0 + alpha)))
    elif off == 0.0:
        magnitude = float(np.sqrt(2.0 * on * np.log((1.0 + alpha) / alpha)))
    else:
        on_term = on * np.log(((1.0 + alpha) / alpha) * (on / (on + off)))
        off_term = off * np.log((1.0 + alpha) * (off / (on + off)))
        two_ln_l = 2.0 * (on_term + off_term)
        # Round-off can push the sum slightly below zero; clip before sqrt.
        magnitude = float(np.sqrt(max(two_ln_l, 0.0)))

    if not signed:
        return magnitude
    return math.copysign(magnitude, on - alpha * off)
from dataclasses import dataclass as _dataclass
from typing import Optional as _Optional, Tuple as _Tuple, Literal as _Literal, Union as _Union
# [docs]  (Sphinx "view source" link residue, commented out: as a bare expression it raises NameError)
@_dataclass
class BackgroundSimple:
    """Minimal background description used for Li & Ma significance.

    Parameters
    ----------
    area_ratio : float
        ON-to-OFF area ratio, A_on / A_off.
    t_off_ref : float
        Reference OFF exposure in seconds.
    n_off_ref : float
        Total OFF counts accumulated during ``t_off_ref``.
    """

    area_ratio: float
    t_off_ref: float
    n_off_ref: float

    def alpha(self, t_on: float) -> float:
        """Return the Li & Ma scaling ``area_ratio * (t_on / t_off_ref)``."""
        exposure_ratio = float(t_on) / float(self.t_off_ref)
        return float(self.area_ratio) * exposure_ratio
# [docs]  (Sphinx "view source" link residue, commented out: as a bare expression it raises NameError)
class TriggerDecider:
    """Decide triggerability from a binned counts lightcurve or event times.

    Core checks
    -----------
    - sliding_window(window=1200): scan the maximum Li&Ma SNR over all windows.
    - head_window(window=1200): Li&Ma SNR of the first window only.
    - cumulative_from_t0(target=7): accumulate counts starting at T0.

    Inputs
    ------
    time : 1D array of bin left edges (monotonic increasing).
    counts : 1D array of ON-region counts per bin (non-negative).
    dt : float bin width in seconds (assumed constant).
    bg : BackgroundSimple with (n_off_ref, t_off_ref, area_ratio).
    """

    def __init__(
        self,
        time: np.ndarray,
        counts: np.ndarray,
        dt: float,
        bg: BackgroundSimple,
    ) -> None:
        """Validate and store the lightcurve; precompute cumulative counts."""
        time = np.asarray(time, dtype=float)
        counts = np.asarray(counts, dtype=float)
        if not (time.ndim == 1 and counts.ndim == 1):
            raise ValueError("time and counts must be 1D arrays")
        if time.size != counts.size:
            raise ValueError("time and counts must have the same length")
        if dt <= 0:
            raise ValueError("dt must be positive")
        self.time, self.counts = time, counts
        self.dt = float(dt)
        self.bg = bg
        # Running total, used for O(1) window sums in `_counts_in`.
        self._cum = np.cumsum(counts)

    @classmethod
    def from_counts(
        cls,
        time: np.ndarray,
        counts: np.ndarray,
        dt: _Optional[float],
        bg: BackgroundSimple,
    ) -> "TriggerDecider":
        """Build from binned counts; a missing ``dt`` is the median bin spacing."""
        time = np.asarray(time, dtype=float)
        counts = np.asarray(counts, dtype=float)
        if dt is not None:
            return cls(time=time, counts=counts, dt=dt, bg=bg)
        if time.size < 2:
            raise ValueError("Need dt or at least two time points to infer dt")
        inferred_dt = float(np.median(np.diff(time)))
        return cls(time=time, counts=counts, dt=inferred_dt, bg=bg)

    @classmethod
    def from_events(
        cls,
        events: np.ndarray,
        *,
        dt: float,
        bg: BackgroundSimple,
        t_start: _Optional[float] = None,
        t_end: _Optional[float] = None,
    ) -> "TriggerDecider":
        """Build from event arrival times by histogramming them into ``dt`` bins.

        ``t_start`` defaults to the earliest event and ``t_end`` to the
        latest event plus one bin width.
        """
        events = np.asarray(events, dtype=float)
        if events.ndim != 1:
            raise ValueError("events must be 1D array of times")
        if events.size == 0:
            raise ValueError("events is empty")
        if dt <= 0:
            raise ValueError("dt must be positive")
        width = float(dt)
        first = float(np.min(events)) if t_start is None else t_start
        last = float(np.max(events)) + width if t_end is None else t_end
        nbins = int(np.ceil((last - first) / width))
        edges = first + np.arange(nbins + 1, dtype=float) * width
        histogram, _ = np.histogram(events, bins=edges)
        return cls(time=edges[:-1], counts=histogram.astype(float), dt=width, bg=bg)

    def _counts_in(self, left: float, right: float) -> float:
        """Sum the counts of the bins whose left edge lies in ``[left, right)``."""
        first = int(np.searchsorted(self.time, left, side="left"))
        stop = int(np.searchsorted(self.time, right, side="left"))
        if stop <= first:
            return 0.0
        before = self._cum[first - 1] if first > 0 else 0.0
        return float(self._cum[stop - 1] - before)

    def _snr_window(
        self, left: float, right: float, n_off_ref: _Optional[float] = None,
    ) -> float:
        """Li & Ma SNR of the window ``[left, right)``; 0.0 for an empty span.

        ``n_off_ref`` overrides the OFF counts taken from ``self.bg``.
        """
        t_on = max(0.0, float(right - left))
        if t_on <= 0:
            return 0.0
        off_counts = self.bg.n_off_ref if n_off_ref is None else n_off_ref
        return li_ma_snr(
            n_on=self._counts_in(left, right),
            n_off=float(off_counts),
            alpha=self.bg.alpha(t_on),
        )

    def sliding_window(
        self, *, window: float = 1200.0, step: _Optional[float] = None,
        target: float = 7.0,
    ) -> _Tuple[bool, dict]:
        """Scan fixed-length windows and report the best SNR.

        Window starts run from the first bin edge in increments of ``step``
        (default: the bin width) up to the last start that keeps the window
        inside the data; at least the first window is always evaluated.
        Returns ``(max_snr >= target, {"max_snr", "best_window"})``.
        """
        if window <= 0:
            raise ValueError("window must be positive")
        step = float(self.dt if step is None else step)
        if step <= 0:
            raise ValueError("step must be positive")
        t0 = float(self.time[0])
        tN = float(self.time[0] + self.counts.size * self.dt)
        last_start = max(t0, tN - window)
        # The small epsilon keeps `last_start` itself in the half-open range.
        starts = np.arange(t0, last_start + 1e-12, step, dtype=float)
        max_snr, best = 0.0, (t0, t0 + window)
        for start in starts:
            stop = start + window
            snr = self._snr_window(start, stop)
            if snr > max_snr:
                max_snr, best = snr, (start, stop)
        stats = {"max_snr": max_snr, "best_window": best}
        return bool(max_snr >= target), stats

    def head_window(
        self, *, window: float = 1200.0, target: float = 7.0,
    ) -> _Tuple[bool, dict]:
        """Evaluate only the window that starts at the first bin edge.

        Returns ``(snr >= target, {"snr", "window"})``.
        """
        start = float(self.time[0])
        stop = start + float(window)
        snr = self._snr_window(start, stop)
        return bool(snr >= target), {"snr": snr, "window": (start, stop)}

    def _find_t0(
        self, mode: _Literal["first_nonzero", "first_time"] = "first_nonzero",
    ) -> float:
        """Pick T0: the first bin edge, or the edge of the first non-empty bin.

        Falls back to the first bin edge when every bin is empty.
        """
        if mode == "first_time":
            return float(self.time[0])
        occupied = self.counts > 0
        index = int(np.argmax(occupied)) if np.any(occupied) else 0
        return float(self.time[index])

    def cumulative_from_t0(
        self,
        *,
        target: float = 7.0,
        t0_mode: _Literal["first_nonzero", "first_time"] = "first_nonzero",
        max_window: _Optional[float] = 1200,
    ) -> _Tuple[bool, dict]:
        """Accumulate counts bin by bin from T0 until ``target`` SNR is reached.

        The scan stops at the first bin where the SNR reaches ``target``, at
        the end of the data, or ``max_window`` seconds after T0 (if given).
        Returns ``(reached, {"T0", "t_reach", "max_snr"})`` where ``t_reach``
        is None when the target was never reached.
        """
        T0 = self._find_t0(mode=t0_mode)
        t_end = float(self.time[0] + self.counts.size * self.dt)
        if max_window is not None:
            t_end = min(t_end, T0 + float(max_window))
        first = int(np.searchsorted(self.time, T0, side="left"))
        stop = int(np.searchsorted(self.time, t_end, side="left"))
        if stop <= first:
            return False, {"T0": T0, "t_reach": None, "max_snr": 0.0}

        running = np.cumsum(self.counts[first:stop])
        n_off = float(self.bg.n_off_ref)
        max_snr = 0.0
        t_reach: _Optional[float] = None
        for nbins, n_on in enumerate(running, start=1):
            t_on = nbins * self.dt
            snr = li_ma_snr(n_on=float(n_on), n_off=n_off, alpha=self.bg.alpha(t_on))
            if snr > max_snr:
                max_snr = snr
            if snr >= float(target):
                t_reach = T0 + t_on
                break
        stats = {"T0": T0, "t_reach": t_reach, "max_snr": max_snr}
        return bool(t_reach is not None), stats

    def decide(
        self,
        *,
        window: float = 1200.0,
        target: float = 7.0,
        step: _Optional[float] = None,
        t0_mode: _Literal["first_nonzero", "first_time"] = "first_nonzero",
    ) -> dict:
        """Run the checks in order: sliding window, head window, cumulative.

        The first check that triggers determines the result. If none of the
        first two triggers, the outcome of the cumulative check (run without
        a window limit) is returned. The dict always has ``triggered`` and
        ``method`` plus the statistics of the reported check.
        """
        triggered, stats = self.sliding_window(window=window, step=step, target=target)
        if triggered:
            return {"triggered": True, "method": "sliding", **stats}

        triggered, stats = self.head_window(window=window, target=target)
        if triggered:
            return {"triggered": True, "method": "head", **stats}

        triggered, stats = self.cumulative_from_t0(
            target=target, t0_mode=t0_mode, max_window=None,
        )
        return {"triggered": bool(triggered), "method": "cumulative", **stats}
# [docs]  (Sphinx "view source" link residue, commented out: as a bare expression it raises NameError)
class LightcurveSNREvaluator:
    """Evaluate whether a binned lightcurve can reach a target SNR after T0.

    T0 is detected via Bayesian Blocks with per-block Li & Ma SNR ≥ 3.
    Supports a fast expected-value mode and an MC mode with Poisson
    fluctuations for ON and OFF counts.

    The background may be either of the two objects from
    ``jinwu.background.backprior``: a ``BackgroundPrior`` (``n_off_prior``,
    ``t_off`` and ``area_ratio`` are read) or a ``BackgroundCountsPosterior``
    (``expected_off``, ``a_total``, ``b`` and ``area_ratio`` are read;
    ``a_total`` / ``b`` are used as Gamma shape and inverse scale).

    Typical usage
    -------------
    >>> bg = BackgroundPrior(n_off_prior=1200, t_off=100000.0, area_ratio=1/12)
    >>> ev = LightcurveSNREvaluator.from_counts(
    ...     time=np.arange(0, 2000.0, 0.5),
    ...     counts=np.random.poisson(0.1, 4000),
    ...     dt=0.5,
    ...     background=bg,
    ... )
    >>> ok, stats = ev.reaches_snr(target=7.0, window=1200.0, mode="fast")
    """

    def __init__(
        self,
        time: np.ndarray,
        counts: np.ndarray,
        dt: float,
        background: _Union["_BackgroundPrior", "_BackgroundCountsPosterior"],
        off_exposure_ref: _Optional[float] = None,
    ) -> None:
        """Validate and store the lightcurve and the background description.

        ``off_exposure_ref`` is only used with a posterior background
        (default 1e6 s); with a prior, the prior's ``t_off`` is used
        (1e6 s if it has none).
        """
        # Imported lazily; the original keeps this import out of module scope.
        from jinwu.background.backprior import (
            BackgroundPrior as _BackgroundPrior,
            BackgroundCountsPosterior as _BackgroundCountsPosterior,
        )
        # Convert first so that plain sequences are accepted (they have no
        # `.ndim`), matching TriggerDecider.__init__.
        time = np.asarray(time, dtype=float)
        counts = np.asarray(counts, dtype=float)
        if time.ndim != 1 or counts.ndim != 1:
            raise ValueError("time and counts must be 1D arrays")
        if time.size != counts.size:
            raise ValueError("time and counts must have the same length")
        if dt <= 0:
            raise ValueError("dt must be positive")
        self.time = time
        self.counts = counts
        self.dt = float(dt)
        self._bg_prior: _Optional[_BackgroundPrior]
        self._bg_post: _Optional[_BackgroundCountsPosterior]
        if isinstance(background, _BackgroundCountsPosterior):
            self._bg_prior = None
            self._bg_post = background
            self.area_ratio = float(background.area_ratio)
            self.off_exposure_ref = float(off_exposure_ref) if off_exposure_ref is not None else 1_000_000.0
        else:
            self._bg_prior = background  # type: ignore[assignment]
            self._bg_post = None
            self.area_ratio = float(background.area_ratio)
            self.off_exposure_ref = float(getattr(background, "t_off", 1_000_000.0))
        # Running total for O(1) block sums in `_block_snr`.
        self._cum_counts = np.cumsum(self.counts)

    @classmethod
    def from_counts(
        cls,
        time: np.ndarray,
        counts: np.ndarray,
        dt: _Optional[float] = None,
        background: _Optional[_Union["_BackgroundPrior", "_BackgroundCountsPosterior"]] = None,
        off_exposure_ref: _Optional[float] = None,
    ) -> "LightcurveSNREvaluator":
        """Build from binned counts; a missing ``dt`` is the median bin spacing.

        Raises ``ValueError`` if ``dt`` cannot be inferred or no background
        is given.
        """
        time = np.asarray(time, dtype=float)
        counts = np.asarray(counts, dtype=float)
        if dt is None:
            if time.size < 2:
                raise ValueError("Need dt or at least two time points to infer dt")
            dt = float(np.median(np.diff(time)))
        if background is None:
            raise ValueError("background must be provided")
        return cls(time=time, counts=counts, dt=dt, background=background, off_exposure_ref=off_exposure_ref)

    @classmethod
    def from_npz(
        cls,
        npz_path: str,
        background: _Union["_BackgroundPrior", "_BackgroundCountsPosterior"],
        *,
        time_key_primary: str = "time_series",
        time_key_fallback: str = "raw_time_series",
        counts_key_preferred: str = "corrected_counts_src",
        net_key: str = "corrected_counts",
        off_key: str = "corrected_counts_back",
        raw_counts_key_fallback: str = "raw_corrected_counts",
        dt: _Optional[float] = None,
        off_exposure_ref: _Optional[float] = None,
        verbose: bool = True,
    ) -> "LightcurveSNREvaluator":
        """Build from an ``.npz`` archive.

        The time axis is read from ``time_key_primary`` or, failing that,
        ``time_key_fallback``. ON counts are taken, in order of preference,
        from ``counts_key_preferred``; from ``net_key + area_ratio *
        off_key``; or from ``raw_counts_key_fallback``. A ``ValueError`` is
        raised when neither a time array nor ON counts can be found.
        """
        # NpzFile keeps the archive open; close it once the arrays are copied.
        with np.load(npz_path) as data:
            if time_key_primary in data:
                time = np.asarray(data[time_key_primary], dtype=float)
                src_time_key = time_key_primary
            elif time_key_fallback in data:
                time = np.asarray(data[time_key_fallback], dtype=float)
                src_time_key = time_key_fallback
            else:
                raise ValueError(
                    f"Cannot find time array in NPZ. "
                    f"Tried '{time_key_primary}' and '{time_key_fallback}'."
                )

            if counts_key_preferred in data:
                counts = np.asarray(data[counts_key_preferred], dtype=float)
                used = counts_key_preferred
            elif (net_key in data) and (off_key in data):
                net = np.asarray(data[net_key], dtype=float)
                off = np.asarray(data[off_key], dtype=float)
                # Rebuild ON counts by adding back the scaled background.
                counts = net + float(background.area_ratio) * off
                used = f"{net_key} + area_ratio*{off_key}"
            elif raw_counts_key_fallback in data:
                counts = np.asarray(data[raw_counts_key_fallback], dtype=float)
                used = raw_counts_key_fallback
                if verbose:
                    print(
                        f"[LightcurveSNREvaluator] Using '{raw_counts_key_fallback}' as ON counts.\n"
                        "If this is actually net counts, SNR will be conservative."
                    )
            else:
                raise ValueError(
                    "Cannot determine ON-region counts from NPZ. Provide one of: "
                    f"'{counts_key_preferred}', or both '{net_key}' & '{off_key}', "
                    f"or '{raw_counts_key_fallback}'."
                )

        if dt is None:
            if time.size < 2:
                raise ValueError("Need dt or at least two time samples to infer dt")
            dt = float(np.median(np.diff(time)))
        if verbose:
            print(
                f"[LightcurveSNREvaluator] Loaded time='{src_time_key}', "
                f"counts='{used}', dt={dt:.6g}s"
            )
        return cls.from_counts(time=time, counts=counts, dt=dt, background=background, off_exposure_ref=off_exposure_ref)

    def _alpha(self, t_on: float) -> float:
        """Li & Ma scaling ``area_ratio * t_on / off_exposure_ref``."""
        return float(self.area_ratio) * (float(t_on) / float(self.off_exposure_ref))

    def _block_snr(self, left: float, right: float, n_off: float) -> float:
        """Li & Ma SNR of the bins whose left edge lies in ``[left, right)``."""
        i0 = int(np.searchsorted(self.time, left, side="left"))
        i1 = int(np.searchsorted(self.time, right, side="left"))
        if i1 <= i0:
            return 0.0
        n_on = float(self._cum_counts[i1 - 1] - (self._cum_counts[i0 - 1] if i0 > 0 else 0.0))
        return li_ma_snr(n_on=n_on, n_off=n_off, alpha=self._alpha(right - left))

    def _draw_n_off(
        self,
        off_mode: _Literal["fixed", "poisson"],
        rng: _Optional[np.random.Generator] = None,
    ) -> float:
        """Return the OFF counts: the expected value or one random draw.

        ``"fixed"`` returns the expectation without touching ``rng``; any
        other mode draws a Poisson realisation (with a Gamma-distributed
        rate first when a posterior background is used).
        """
        if self._bg_post is not None:
            post = self._bg_post
            if off_mode == "fixed":
                return float(post.expected_off(self.off_exposure_ref))
            if rng is None:
                rng = np.random.default_rng()
            lam_off = rng.gamma(shape=float(post.a_total), scale=1.0 / float(post.b))
            return float(rng.poisson(lam_off * float(self.off_exposure_ref)))
        prior = self._bg_prior
        if off_mode == "fixed":
            return float(prior.n_off_prior)  # type: ignore[union-attr]
        if rng is None:
            rng = np.random.default_rng()
        mu_off = float(prior.n_off_prior) / float(prior.t_off)  # type: ignore[union-attr]
        return float(rng.poisson(mu_off * prior.t_off))  # type: ignore[union-attr]

    def _bayesian_block_edges(self) -> np.ndarray:
        """Bayesian Blocks change points of the stored lightcurve."""
        from astropy.stats import bayesian_blocks

        return bayesian_blocks(self.time, self.counts, fitness="measures")

    def _max_cumulative_snr(self, on_counts: np.ndarray, n_off: float) -> float:
        """Largest Li & Ma SNR over all prefixes of ``on_counts`` (0.0 at least)."""
        csum = np.cumsum(on_counts)
        max_snr = 0.0
        for k in range(1, csum.size + 1):
            snr = li_ma_snr(n_on=float(csum[k - 1]), n_off=n_off, alpha=self._alpha(k * self.dt))
            if snr > max_snr:
                max_snr = snr
        return max_snr

    def _find_T0_by_blocks(
        self,
        snr_thr: float = 3.0,
        n_off: _Optional[float] = None,
        rng: _Optional[np.random.Generator] = None,
        off_mode: _Literal["fixed", "poisson"] = "fixed",
        edges: _Optional[np.ndarray] = None,
    ) -> float:
        """Return the start of the first Bayesian block with SNR >= ``snr_thr``.

        ``n_off`` defaults to a value obtained according to ``off_mode``.
        ``edges`` lets callers pass precomputed block edges so that the
        segmentation is not repeated for identical data. If no block
        qualifies, the first bin edge is returned.
        """
        if n_off is None:
            n_off = self._draw_n_off(off_mode, rng)
        if edges is None:
            edges = self._bayesian_block_edges()
        for i in range(len(edges) - 1):
            left, right = float(edges[i]), float(edges[i + 1])
            if self._block_snr(left, right, n_off=n_off) >= snr_thr:
                return left
        return float(self.time[0])

    def reaches_snr(
        self,
        target: float = 7.0,
        window: float = 1200.0,
        mode: _Literal["fast", "mc"] = "mc",
        n_mc: int = 500,
        rng: _Optional[np.random.Generator] = None,
        t0_snr_thr: float = 3.0,
        off_mode: _Literal["fixed", "poisson"] = "poisson",
    ) -> _Tuple[bool, dict]:
        """Test whether the cumulative SNR after T0 reaches ``target``.

        Counts are accumulated bin by bin from T0 over at most ``window``
        seconds.

        - ``mode="fast"``: uses the observed counts and the expected OFF
          counts; returns ``(max_snr >= target, {"T0", "max_snr"})``.
        - any other mode (MC): repeats ``n_mc`` trials with resampled ON
          counts and OFF counts drawn per ``off_mode``; returns
          ``(prob >= 0.95, {"prob", "max_snrs"})`` where ``prob`` is the
          fraction of trials reaching ``target``.

        Raises ``ValueError`` in MC mode when ``n_mc`` is not positive.
        """
        if rng is None:
            rng = np.random.default_rng()

        if mode == "fast":
            n_off_exp = self._draw_n_off("fixed")
            T0 = self._find_T0_by_blocks(snr_thr=t0_snr_thr, n_off=n_off_exp, off_mode="fixed")
            i0 = int(np.searchsorted(self.time, T0, side="left"))
            i1 = int(np.searchsorted(self.time, T0 + float(window), side="left"))
            if i1 <= i0:
                return False, {"T0": T0, "max_snr": 0.0}
            max_snr = self._max_cumulative_snr(self.counts[i0:i1], float(n_off_exp))
            return bool(max_snr >= target), {"T0": T0, "max_snr": max_snr}

        # MC mode
        n_trials = int(n_mc)
        if n_trials <= 0:
            raise ValueError("n_mc must be a positive integer")
        # The segmentation depends only on the stored data, so compute it
        # once instead of once per trial.
        edges = self._bayesian_block_edges()
        hits = 0
        max_snrs = []
        for _ in range(n_trials):
            n_off = self._draw_n_off(off_mode, rng)
            T0 = self._find_T0_by_blocks(
                snr_thr=t0_snr_thr, n_off=n_off, rng=rng, off_mode=off_mode, edges=edges,
            )
            i0 = int(np.searchsorted(self.time, T0, side="left"))
            i1 = int(np.searchsorted(self.time, T0 + float(window), side="left"))
            if i1 <= i0:
                max_snrs.append(0.0)
                continue
            bins = slice(i0, i1)
            if self._bg_post is not None:
                # Split each bin into a resampled source part and a
                # resampled background part with a freshly drawn rate.
                lam_off = rng.gamma(shape=float(self._bg_post.a_total), scale=1.0 / float(self._bg_post.b))
                mu_bkg_bin = float(lam_off) * float(self.area_ratio) * float(self.dt)
                lam_on_obs = np.clip(self.counts[bins], 0.0, None)
                mu_src_bin = np.clip(lam_on_obs - mu_bkg_bin, 0.0, None)
                n_src_bins = rng.poisson(mu_src_bin)
                n_bkg_bins = rng.poisson(mu_bkg_bin, size=n_src_bins.size)
                n_on_bins = n_src_bins + n_bkg_bins
            else:
                # Resample the observed ON counts directly.
                lam_on = np.clip(self.counts[bins], 0.0, None)
                n_on_bins = rng.poisson(lam_on)
            max_snr = self._max_cumulative_snr(n_on_bins, n_off)
            max_snrs.append(float(max_snr))
            hits += int(max_snr >= target)
        prob = hits / float(n_trials)
        return bool(prob >= 0.95), {"prob": prob, "max_snrs": np.asarray(max_snrs)}