from mmf.solve.test_problems import TwoDimensional
from mmf.solve.solver_interface import solve, IterationError
from mmf.solve.broyden_solvers import Broyden

# NOTE(review): `np` and `plt` are used throughout this script but are not
# imported in the visible code; it presumably runs in a pylab/IPython
# namespace -- confirm.
potential = TwoDimensional.SensitiveParameter()
# debug=True: later code reads `solver._debug.steps` after each solve
# (presumably this flag is what makes the solver record them).
solver = Broyden(verbosity=0, debug=True)

# Make a contour plot of the potential over the plotting window.
axis = [-1, 12, -13, 15]        # [xmin, xmax, ymin, ymax], passed to plt.axis
x = np.linspace(axis[0], axis[1], 1000)
y = np.linspace(axis[2], axis[3], 1000)
X, Y = np.meshgrid(x, y)
Z = potential.f(X, Y)

# Contour Z**0.1 rather than Z: the small power compresses large values so the
# 10 levels spread over the window (assumes Z >= 0; negatives would give nan).
# The very low zorder keeps the contours beneath the solver paths drawn later.
plt.contour(X, Y, Z ** 0.1, 10, zorder=-1000)

# First attempt: solve along good orthogonal directions.
def f(x):
    """Return the potential at ``(x, y*(x))``, where ``y*`` solves ``f_y = 0``.

    For the given ``x``, first solve ``potential.f_y(x, y) = 0`` for ``y``.
    The solve is warm-started from the previous solution held in the
    module-level ``y0[0]``. Then evaluate ``potential.f`` at that point.

    Side effects (all on module-level containers):

    * ``y0[0]`` is overwritten with the new ``y`` solution;
    * ``xs``/``ys`` receive every point sampled by the inner solve;
    * ``xs_``/``ys_`` receive the first element of each entry in
      ``solver._debug.steps`` (presumably the iterates the inner solver
      actually took), paired with the fixed ``x``.
    """
    def g(y):
        # Residual for the inner 1-D solve at fixed x; records every sample.
        xs.append(x)
        ys.append(y)
        return potential.f_y(x, y)

    # NOTE(review): the same `solver` instance is used here and by the outer
    # solve that calls f. If Broyden keeps per-solve state on the instance,
    # this nesting may interfere -- confirm it is reentrant.
    y0[0] = solve(f=g, x0=y0[0], solver=solver).x_[0]

    # First component of each recorded step. list() is required because zip()
    # returns a non-subscriptable iterator on Python 3 (no-op on Python 2).
    y_ = np.array(list(zip(*solver._debug.steps))[0])
    ys_.extend(y_)
    # The inner path lies at fixed x: broadcast x to the same length as y_.
    xs_.extend(0*y_ + x)

    # NOTE(review): the outer residual is potential.f itself, not
    # potential.f_x -- confirm this is intended.
    return potential.f(x, y0[0])

y0 = [9]                     # Current y solution; a list so f can update it in place
xs, ys = [], []              # All points sampled
xs_, ys_ = [], []            # Points actually tried

# The return value is not needed: the path is recorded through f's side effects.
solve(f=f, x0=9, solver=solver)

# Flatten everything to 1-D arrays for plotting.
xs, ys, xs_, ys_ = (np.asarray(v).ravel() for v in (xs, ys, xs_, ys_))
plt.plot(xs, ys, 'b.', zorder=101)
plt.plot(xs_, ys_, 'b-x', zorder=101)

# Draw an arrow over the first half of each step to show the direction of travel.
dx = np.diff(xs_)
dy = np.diff(ys_)
for n in range(len(dx)):
    plt.arrow(xs_[n], ys_[n], dx[n]/2, dy[n]/2,
              linewidth=0, head_width=0.2, zorder=100,
              color='b')

# Now try solving the full 2-D problem at once: this fails!
def F(x):
    """Return ``[potential.f_x, potential.f_y]`` evaluated at ``(x[0], x[1])``.

    Presumably this is the gradient of the potential. Each call also
    appends ``x[0]`` to the module-level ``xs`` and ``x[1]`` to ``ys``, so
    the sampled path can be plotted afterwards.
    """
    px = x[0]
    xs.append(px)
    py = x[1]
    ys.append(py)
    components = [potential.f_x(px, py), potential.f_y(px, py)]
    return np.array(components)

xs, ys = [], []
try:
    solve(f=F, x0=[9, 9], solver=solver)
except IterationError:
    # Expected: the full 2-D solve does not converge. Carry on so that the
    # points F recorded (and presumably the solver's debug steps) can still
    # be plotted.
    pass

plt.plot(xs, ys, 'rx', zorder=200)
# First component of each recorded step. list() is required because zip()
# returns a non-subscriptable iterator on Python 3 (no-op on Python 2).
xys = np.array(list(zip(*solver._debug.steps))[0])

plt.plot(xys[:, 0], xys[:, 1], 'x-r', zorder=200)

# Draw an arrow over the first half of each step to show the direction of travel.
dx = np.diff(xys[:, 0])
dy = np.diff(xys[:, 1])
for n in range(len(dx)):
    plt.arrow(xys[n, 0], xys[n, 1], dx[n]/2, dy[n]/2,
              linewidth=0, head_width=0.01, zorder=200,
              color='r')

# Even Newton's method has difficulty.
# Take 100 plain Newton steps from (9, 9): subtract Jinv(x, y) applied to the
# residual (f_x, f_y), recording each iterate before it is updated.
xs, ys = [], []
x = y = 9
for n in range(100):
    xs.append(x)
    ys.append(y)
    # Evaluate the residual directly rather than via F. F also appends to
    # xs/ys, which recorded every iterate twice and made half of the arrows
    # below zero-length.
    residual = np.array([potential.f_x(x, y), potential.f_y(x, y)])
    x, y = np.array([x, y]) - np.dot(potential.Jinv(x, y), residual)

plt.plot(xs, ys, 'g:x', zorder=99)

# Draw an arrow over the first half of each step to show the direction of travel.
dx = np.diff(xs)
dy = np.diff(ys)
for n in range(len(dx)):
    plt.arrow(xs[n], ys[n], dx[n]/2, dy[n]/2,
              linewidth=0, head_width=0.2, zorder=99,
              color='g')

plt.axis(axis)