import astropy.units as u
import numpy as np
import matplotlib.pyplot as plt

from astropy.modeling.powerlaws import PowerLaw1D
from astropy.visualization import quantity_support

from sunkit_spex.models.physical.albedo import Albedo

# Plot a power-law source spectrum, the spectrum after passing through the
# Albedo model ("Observed"), and their difference ("Reflected") for several
# values of the Albedo model's ``theta`` parameter.

# Energy range shared by the model grid and the plot x-limits, so the two
# cannot drift apart.
e_min, e_max = 5, 550  # keV
e_edges = np.linspace(e_min, e_max, 600) * u.keV
# Bin centres: midpoint of each consecutive pair of edges.
e_centers = e_edges[:-1] + 0.5 * np.diff(e_edges)

source = PowerLaw1D(amplitude=1*u.ph/(u.cm*u.s), x_0=5*u.keV, alpha=3)
albedo = Albedo(energy_edges=e_edges)
# ``|`` composes the astropy models: the source output is fed into ``albedo``.
observed = source | albedo

# The source spectrum does not depend on theta, so evaluate it only once.
source_flux = source(e_centers)

with quantity_support():
    plt.figure()
    plt.plot(e_centers, source_flux, 'k', label='Source')
    for i, t in enumerate([0, 45, 90]*u.deg):
        # The loop relies on this update to ``albedo`` being reflected when
        # the compound ``observed`` model is evaluated.
        albedo.theta = t
        # Evaluate the compound model once per angle and reuse the result
        # for both curves instead of recomputing it.
        observed_flux = observed(e_centers)
        color = f'C{i+1}'
        plt.plot(e_centers, observed_flux, '--', label=f'Observed, theta={t}', color=color)
        plt.plot(e_centers, observed_flux - source_flux, ':',
                 label=f'Reflected, theta={t}', color=color)

    plt.ylim(1e-6, 1)
    plt.xlim(e_min, e_max)
    plt.loglog()
    plt.legend()
    plt.show()