import numpy as np
import pytest
from scipy.constants import c, e

from mbtrack2 import Electron, Synchrotron


def test_synchrotron_values(local_optics):
    """Check that a Synchrotron exposes its inputs and consistent derived values.

    A ring is built from explicit parameters. The test first verifies that
    every constructor argument can be read back unchanged as an attribute,
    then that the derived attributes (``T0``, ``T1``, ``f0``, ``f1``,
    ``omega0``, ``omega1``, ``k1``, ``gamma``, ``beta``) agree with the
    closed-form expressions evaluated from the same inputs.

    Parameters
    ----------
    local_optics : fixture
        Optics object passed as the second positional argument of
        ``Synchrotron``; it is defined outside this file.
    """
    h = 20
    L = 100
    E0 = 1e9
    particle = Electron()
    ac = 1e-3
    U0 = 250e3
    tau = np.array([10e-3, 10e-3, 5e-3])
    tune = np.array([18.2, 10.3])
    emit = np.array([50e-9, 50e-9 * 0.01])
    sigma_0 = 30e-12
    sigma_delta = 1e-3
    chro = [1.0, 1.0]

    ring = Synchrotron(h, local_optics, particle, L=L, E0=E0, ac=ac, U0=U0,
                       tau=tau, emit=emit, tune=tune, sigma_delta=sigma_delta,
                       sigma_0=sigma_0, chro=chro)

    # Constructor arguments must round-trip unchanged.  The expected value is
    # the one wrapped in pytest.approx so that the tolerance is derived from
    # the reference, not from the value under test.
    assert ring.h == pytest.approx(h)
    assert ring.L == pytest.approx(L)
    assert ring.E0 == pytest.approx(E0)
    assert ring.U0 == pytest.approx(U0)
    assert ring.ac == pytest.approx(ac)
    np.testing.assert_allclose(ring.tau, tau)
    np.testing.assert_allclose(ring.tune, tune)
    np.testing.assert_allclose(ring.emit, emit)
    assert ring.sigma_0 == pytest.approx(sigma_0)
    assert ring.sigma_delta == pytest.approx(sigma_delta)
    np.testing.assert_allclose(ring.chro, chro)

    # Reference values for the derived quantities, computed once from the
    # inputs above.
    T0 = L / c
    T1 = L / c / h
    # NOTE(review): presumably particle.mass * c**2 / e is the rest energy in
    # the same unit as E0 -- confirm against the particle class.
    gamma = E0 / (particle.mass * c**2 / e)

    assert ring.T0 == pytest.approx(T0)
    assert ring.T1 == pytest.approx(T1)
    assert ring.f0 == pytest.approx(c / L)
    assert ring.f1 == pytest.approx(1 / T1)
    assert ring.omega0 == pytest.approx(2 * np.pi * c / L)
    assert ring.omega1 == pytest.approx(2 * np.pi * 1 / T1)
    assert ring.k1 == pytest.approx(2 * np.pi * 1 / T1 / c)
    assert ring.gamma == pytest.approx(gamma)
    assert ring.beta == pytest.approx(np.sqrt(1 - gamma**-2))
    
def test_synchrotron_mcf(demo_ring):
    """Check ``mcf()`` and ``eta()`` after assigning explicit ``mcf_order``.

    With ``mcf_order = [5e-4, 1e-4, 1e-3]`` the test expects ``mcf(x)`` to
    equal ``5e-4*x**2 + 1e-4*x + 1e-3`` (coefficients listed from the highest
    power down) and ``eta(x)`` to equal ``mcf(x) - 1/gamma**2``.

    Parameters
    ----------
    demo_ring : fixture
        Ring object defined outside this file.  Its ``mcf_order`` attribute
        is overwritten by this test.
    """
    delta = 0.5
    demo_ring.mcf_order = [5e-4, 1e-4, 1e-3]

    # The expected value goes inside pytest.approx so the tolerance is
    # computed from the reference rather than from the value under test.
    expected_mcf = 5e-4 * (delta**2) + 1e-4 * delta + 1e-3
    assert demo_ring.mcf(delta) == pytest.approx(expected_mcf)

    expected_eta = demo_ring.mcf(delta) - 1 / (demo_ring.gamma**2)
    assert demo_ring.eta(delta) == pytest.approx(expected_eta)
    
def test_synchrotron_tune(demo_ring):
    """Compare ``synchrotron_tune(1e6)`` with a stored reference value.

    Parameters
    ----------
    demo_ring : fixture
        Ring object defined outside this file.
    """
    tuneS = demo_ring.synchrotron_tune(1e6)
    # The reference number is wrapped in pytest.approx so the 1e-4 relative
    # tolerance is scaled by the reference, not by the computed tune.
    assert tuneS == pytest.approx(0.0017553, rel=1e-4)
    
def test_synchrotron_long_twiss(demo_ring):
    """Check the four values returned by ``get_longitudinal_twiss``.

    The first returned value must agree with ``synchrotron_tune`` called with
    the same argument; the remaining three are compared with stored
    reference numbers, all to a relative tolerance of 1e-4.

    Parameters
    ----------
    demo_ring : fixture
        Ring object defined outside this file.
    """
    tuneS, long_alpha, long_beta, long_gamma = demo_ring.get_longitudinal_twiss(
        1e6, add=False)

    # Expected values go inside pytest.approx so each tolerance is scaled by
    # the reference rather than by the value under test.
    assert tuneS == pytest.approx(demo_ring.synchrotron_tune(1e6), rel=1e-4)
    assert long_alpha == pytest.approx(-0.0055146, rel=1e-4)
    assert long_beta == pytest.approx(3.0236e-08, rel=1e-4)
    assert long_gamma == pytest.approx(3.30736e7, rel=1e-4)
    
def test_synchrotron_sigma(demo_ring):
    """Compare the four values returned by ``sigma()`` with reference numbers.

    Parameters
    ----------
    demo_ring : fixture
        Ring object defined outside this file.
    """
    expected_sigma = np.array([
        2.23606798e-04,
        2.23606798e-04,
        2.23606798e-05,
        2.23606798e-05,
    ])
    computed_sigma = demo_ring.sigma()
    np.testing.assert_allclose(computed_sigma, expected_sigma)