# Source code for sunkit_spex.models.models
"""Module for generic mathematical models."""
import numpy as np
import astropy.units as u
from astropy.modeling import FittableModel, Parameter
from astropy.units import Quantity
__all__ = ["GaussianModel", "StraightLineModel"]
[docs]
class StraightLineModel(FittableModel):
    """Straight line model, ``y = slope * x + intercept``.

    Parameters
    ----------
    slope
        Gradient of the line (parameter default 1).
    intercept
        Y-intercept of the line (parameter default 0).
    edges : bool, optional
        If `True` (the default), the ``x`` passed to `evaluate` is treated as
        a sequence of edges and the model is evaluated at the midpoint between
        each consecutive pair, so the output has one fewer element than ``x``.
        If `False`, the model is evaluated at ``x`` directly.
    **kwargs
        Forwarded unchanged to the parent class constructor.
    """

    n_inputs = 1
    n_outputs = 1
    _input_units_allow_dimensionless = True
    # NOTE(review): this mapping is keyed by "keV", while `input_units` below
    # is keyed by the input name "x" -- confirm the key is as intended.
    input_units_equivalencies = {"keV": u.spectral()}

    slope = Parameter(default=1, description="Gradient of a straight line model.")
    intercept = Parameter(default=0, description="Y-intercept of a straight line model.")

    def __init__(self, slope=slope, intercept=intercept, edges=True, **kwargs):
        # Stored on the instance before delegating to the parent constructor;
        # `evaluate` reads it to decide whether `x` holds edges.
        self.edges = edges
        super().__init__(slope, intercept, **kwargs)

    def evaluate(self, x, slope, intercept):
        """Evaluate the straight line model at `x` with parameters `slope` and `intercept`.

        Parameters
        ----------
        x
            Points at which to evaluate the model. When ``self.edges`` is
            true these are treated as edges and replaced by the midpoints of
            consecutive pairs before evaluation.
        slope
            Gradient of the line.
        intercept
            Y-intercept of the line.

        Returns
        -------
        ``slope * x + intercept``, with one fewer element than the supplied
        ``x`` when ``self.edges`` is true.
        """
        if self.edges:
            # Midpoint of each consecutive pair: left edge plus half the width.
            x = x[:-1] + 0.5 * np.diff(x)
        return slope * x + intercept

    @property
    def input_units(self):
        """Units of ``x`` as ``{"x": intercept.unit / slope.unit}``, or `None`.

        `None` is returned whenever ``self.slope`` is not a `Quantity`.
        """
        # NOTE(review): confirm that ``self.slope`` can actually be a
        # `Quantity` instance here; if not, this branch is never taken.
        if isinstance(self.slope, Quantity):
            return {"x": self.intercept.unit / self.slope.unit}
        return None

    @property
    def return_units(self):
        """Units of ``y`` as ``{"y": intercept.unit}``, or `None`.

        `None` is returned whenever ``self.slope`` is not a `Quantity`.
        """
        if isinstance(self.slope, Quantity):
            return {"y": self.intercept.unit}
        return None

    def _parameter_units_for_data_units(self, input_units, output_units):
        """Map each parameter name to the unit implied by the data units.

        ``slope`` takes the ``"y"`` output unit divided by the ``"x"`` input
        unit; ``intercept`` takes the ``"y"`` output unit.
        """
        return {"slope": output_units["y"] / input_units["x"], "intercept": output_units["y"]}
[docs]
class GaussianModel(FittableModel):
    """Gaussian model, ``amplitude * e ** (-(x - mean) ** 2 / (2 * stddev ** 2))``.

    Parameters
    ----------
    amplitude
        Scalar multiplying the Gaussian (parameter default 1, minimum 0).
    mean
        X-offset of the Gaussian (parameter default 0, minimum 0).
    stddev
        Sigma of the Gaussian (parameter default 1).
    edges : bool, optional
        If `True` (the default), the ``x`` passed to `evaluate` is treated as
        a sequence of edges and the model is evaluated at the midpoint between
        each consecutive pair, so the output has one fewer element than ``x``.
        If `False`, the model is evaluated at ``x`` directly.
    **kwargs
        Forwarded unchanged to the parent class constructor.
    """

    n_inputs = 1
    n_outputs = 1
    _input_units_allow_dimensionless = True

    amplitude = Parameter(default=1, min=0, description="Scalar for Gaussian.")
    mean = Parameter(default=0, min=0, description="X-offset for Gaussian.")
    stddev = Parameter(default=1, description="Sigma for Gaussian.")

    def __init__(self, amplitude=amplitude, mean=mean, stddev=stddev, edges=True, **kwargs):
        # Recorded on the instance first; `evaluate` consults it on every call.
        self.edges = edges
        super().__init__(amplitude, mean, stddev, **kwargs)

    def evaluate(self, x, amplitude, mean, stddev):
        """Evaluate the Gaussian model at `x` with parameters `amplitude`, `mean`, and `stddev`.

        When ``self.edges`` is true, ``x`` is first replaced by the midpoints
        of its consecutive elements, so the result has one fewer element than
        the supplied ``x``.
        """
        # Either the midpoints between consecutive edges, or `x` untouched.
        centres = (x[:-1] + 0.5 * np.diff(x)) if self.edges else x
        exponent = -((centres - mean) ** 2) / (2 * stddev**2)
        return amplitude * np.e**exponent

    @property
    def input_units(self):
        """Units of ``x`` as ``{"x": mean.unit}``, or `None` when ``self.mean`` is not a `Quantity`."""
        if not isinstance(self.mean, Quantity):
            return None
        return {"x": self.mean.unit}

    @property
    def return_units(self):
        """Units of ``y`` as ``{"y": amplitude.unit}``, or `None` when ``self.amplitude`` is not a `Quantity`."""
        if not isinstance(self.amplitude, Quantity):
            return None
        return {"y": self.amplitude.unit}

    def _parameter_units_for_data_units(self, input_units, output_units):
        """Map each parameter name to the unit implied by the data units.

        ``mean`` and ``stddev`` take the ``"x"`` input unit; ``amplitude``
        takes the ``"y"`` output unit.
        """
        x_unit = input_units["x"]
        y_unit = output_units["y"]
        return {"mean": x_unit, "stddev": x_unit, "amplitude": y_unit}