import mmf.math.integrate.integrate_1d.imt
# Single IMT instance from mmf; its ``phi`` method is called in the plotting
# loop below, and the callable it returns is evaluated on the grid ``t``.
imt = mmf.math.integrate.integrate_1d.imt.IMT()

def f(t, k, eta, theta, b, c):
    """Return one of four related integrands evaluated at ``t``.

    With ``beta_m = 1/theta``, ``beta_mu = eta + beta_m`` and
    ``x = t + beta_m`` the integrand is::

        t**k * sqrt(1 + t/(2*beta_m)) * num / (cosh(x) + cosh(beta_mu))**b

    where ``num`` is selected by ``b`` and ``c``:

    * ``b == 1, c == 0``: ``sinh(beta_mu)``
    * ``b == 1, c != 0``: ``cosh(beta_mu) + exp(-x)``
    * ``b == 2, c == 0``: ``sinh(x) * sinh(beta_mu)``
    * ``b == 2, c != 0``: ``1 + cosh(beta_mu) * cosh(x)``

    The hyperbolic functions are evaluated after factoring out
    ``exp(max(|x|, |beta_mu|))``, which cancels between numerator and
    denominator (both carry the same power ``b`` of it).  This reproduces
    the direct formula but avoids ``inf/inf = nan`` once ``x`` or
    ``beta_mu`` exceeds ~710 (e.g. ``theta = 0.01`` with large ``t``).

    Parameters
    ----------
    t : float or ndarray
        Integration variable.
    k : float
        Power of ``t``.
    eta : float
        Added to ``beta_m`` to form ``beta_mu``.
    theta : float
        Non-zero; ``beta_m = 1/theta``.
    b : {1, 2}
        Power of the denominator.
    c : int
        ``0`` selects the ``sinh`` numerator; any other value selects the
        ``cosh`` one.

    Returns
    -------
    float or ndarray
        The integrand, with the same shape as ``t``.

    Raises
    ------
    ValueError
        If ``b`` is neither 1 nor 2.
    """
    beta_m = 1.0/theta
    beta_mu = eta + beta_m
    if b not in (1, 2):
        raise ValueError("b must be 1 or 2, got %r" % (b,))
    x = t + beta_m

    # Scale every cosh/sinh by exp(-s), s = max(|x|, |beta_mu|), so that all
    # exponents below are <= 0 and nothing can overflow.
    s = np.maximum(np.abs(x), np.abs(beta_mu))

    def scaled_cosh_sinh(z):
        """Return ``(cosh(z)*exp(-s), sinh(z)*exp(-s))`` without overflow."""
        az = np.abs(z)
        scale = np.exp(az - s)
        # expm1 keeps the sinh part accurate when |z| is small.
        return (scale*(1 + np.exp(-2*az))/2,
                -np.sign(z)*scale*np.expm1(-2*az)/2)

    cosh_x, sinh_x = scaled_cosh_sinh(x)
    cosh_mu, sinh_mu = scaled_cosh_sinh(beta_mu)
    # Always >= 1/2: whichever argument attains s contributes at least 1/2.
    den = cosh_x + cosh_mu
    if 1 == b:
        if 0 == c:
            num = sinh_mu
        else:
            # exp(-x) scaled by exp(-s); -x - s <= 0 since |x| <= s.
            num = cosh_mu + np.exp(-x - s)
    else:
        if 0 == c:
            num = sinh_x*sinh_mu
        else:
            # The constant 1 picks up exp(-2*s) from the exp(2*s) scaling.
            num = np.exp(-2*s) + cosh_mu*cosh_x
    res = t**k*np.sqrt(1 + t/beta_m/2)*num/den**b
    return res

# Plot the output of ``imt.phi`` for each of the four (b, c) variants of
# ``f``.  Each case is split at ``t0``: ``phi0`` uses the range (0, t0) and
# ``phi1`` uses (t0, 60*eta) -- presumably the integration limits.  Both are
# evaluated on the grid ``t`` and normalised by their common maximum so the
# curves can be overlaid.
t = np.linspace(0, 1, 50)
# Subplot titles keyed by (b, c); the subplot index b + 2*c places them as
# (1, 0) -> 1, (2, 0) -> 2, (1, 1) -> 3, (2, 1) -> 4.
titles = {(1, 0): "n", (1, 1): "n_s", (2, 0): "dn_dm", (2, 1): "dn_dmu"}
for b in [1, 2]:
    for c in [0, 1]:
        plt.subplot(2, 2, b + 2*c)
        plt.title(titles[(b, c)])
        for eta in [0.01, 1.0, 100]:
            for theta in [0.01, 1.0, 100.0]:
                for k in [0.5, 1.5, 2.5]:
                    a = k + 0.5
                    # Split point.  An alternative choice kept from earlier
                    # experimentation:
                    #     t0 = a + mmf.math.LambertW(a*np.exp(eta - a))
                    t0 = a + eta
                    # Bound and used within this iteration only, so the
                    # lambda's late binding of the loop variables is harmless.
                    f_ = lambda x: f(x, k, eta, theta, b, c)
                    phi0 = imt.phi(f_, 0, t0)(t)
                    # NOTE(review): for eta = 0.01 the upper limit
                    # 60*eta = 0.6 is below t0 (>= 1.01), so this range is
                    # reversed -- confirm that is intended.
                    phi1 = imt.phi(f_, t0, 60*eta)(t)
                    # Single-argument call form: same output on Python 2,
                    # and valid syntax on Python 3.
                    print(t0/eta)
                    phi_max = max(phi0.max(), phi1.max())
                    plt.plot(t, phi0/phi_max, '--y')
                    plt.plot(t, phi1/phi_max, '--r')