import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt

from sunkit_spex.models.scaling import InverseSquareFluxScaling
from sunkit_spex.models.models import StraightLineModel

# Units of the simulated photon spectrum (y) and of the energy axis (x).
y_units = u.ph * u.keV**-1 * u.s**-1
x_units = u.keV

# Energy bin edges from 4 keV up to 99.5 keV in 0.5 keV steps, and the
# midpoints of consecutive edges, used as the x-coordinates of each curve.
ph_energies = np.arange(4, 100, 0.5) * x_units
ph_energies_centers = ph_energies[:-1] + 0.5 * np.diff(ph_energies)

# Simulated straight-line source spectrum: slope * E + intercept.
# NOTE: with slope -2 and intercept 100 the line reaches zero at E = 50 keV and
# is negative above it, so that half of every curve cannot be drawn on the
# log-log axes used below.
sim_cont = {"slope": -2 * y_units / x_units, "intercept": 100 * y_units}
source = StraightLineModel(**sim_cont)

# Observer distances, in AU, at which the same source is viewed.
observer_distances_au = [0.25, 0.5, 1]

fig, ax = plt.subplots()
for d in observer_distances_au:
    # presumably rescales the source by the inverse square of the observer
    # distance -- confirm against InverseSquareFluxScaling.
    distance = InverseSquareFluxScaling(observer_distance=d * u.AU)
    observed = source * distance
    # assumes the compound model is evaluated on bin edges and returns one
    # value per bin, matching len(ph_energies_centers) -- TODO confirm.
    ax.plot(ph_energies_centers, observed(ph_energies), label=f"D = {d} AU")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel(f"Energy [{x_units.to_string()}]")
ax.set_ylabel("Observed spectrum")
ax.legend()
plt.show()