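# Gauss(-Legendre) quadrature points and weights on [-1, 1], found by
# matching the polynomial moments of the rule to the exact integrals.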
def m_(n):
    """Return the first `n` moments of the constant weight on [-1, 1].

    Entry k (k = 0 .. n-1) is the integral of x**k over [-1, 1], namely
    (1 - (-1)**(k+1))/(k+1): 2/(k+1) for even k and 0 for odd k.  The
    result is an (n, 1) ``numpy.matrix`` column so callers can use ``*``
    for matrix products.
    """
    k1 = np.arange(n) + 1  # k + 1 for k = 0 .. n-1
    m = (1.0 - (-1)**k1)/k1
    # np.asmatrix instead of the np.mat alias, which NumPy 2.0 removed.
    return np.asmatrix(m).T

def X_(x, n):
    """Return the (n, len(x)) matrix whose row k holds ``x**k``.

    Row k (k = 0 .. n-1) is the k-th power of each abscissa, so
    ``X_(x, n) * w`` gives the first `n` moments of the rule with weights
    `w`.  A scalar `x` yields an (n, 1) matrix.  The result is a
    ``numpy.matrix`` so callers can use ``*`` for matrix products.
    """
    powers = np.arange(n).reshape((n, 1))
    # np.asmatrix instead of the np.mat alias, which NumPy 2.0 removed;
    # np.asarray lets plain sequences be passed as well as arrays.
    return np.asmatrix(np.asarray(x)**powers)
def w_(x, n):
    """Return weights at abscissae `x` that reproduce the first `n` moments.

    Solves the moment equations ``X * w = m`` with ``X = X_(x, n)`` and
    ``m = m_(n)`` through the normal equations: ``(X X^T) l = m`` and then
    ``w = X^T l``.  When len(x) > n and X has full row rank, this is the
    minimum-norm solution.  The result is a (len(x), 1) ``numpy.matrix``.
    NOTE(review): if len(x) < n, ``X X^T`` is rank-deficient and the solve
    is expected to fail or be meaningless.
    """
    vander = X_(x, n)
    gram = vander*vander.T
    multipliers = np.linalg.solve(gram, m_(n))
    return vander.T*multipliers
def f(x):
    """Return the moment residuals of the quadrature rule packed into `x`.

    The first half of `x` holds the abscissae and the second half the
    weights.  With ``half = len(x)//2`` nodes, entry k of the returned 1-D
    array (k = 0 .. 2*half-1) is ``sum_i w_i * x_i**k - m_(2*half)[k]``,
    i.e. the error of the rule on the monomial x**k over [-1, 1].
    """
    x = np.asarray(x)
    half = len(x)//2
    nodes = x[:half]
    weights = x[half:].ravel()
    vander = np.asarray(X_(nodes, 2*half))
    moments = np.asarray(m_(2*half)).ravel()
    return vander.dot(weights) - moments
def find_n(x):
    """Return how many leading moments can be matched with non-negative weights.

    Tries k = 1, 2, ..., len(x) moments in turn.  For the first k whose
    weights ``w_(x, k)`` contain a negative entry, it returns k - 1.  If no
    such k exists, it returns len(x), which is 0 for an empty `x`.
    """
    count = len(x)
    for n in range(count):
        # w_(x, n + 1) matches n + 1 moments; stop at the first negative weight.
        if np.min(w_(x, n + 1)) < 0:
            return n
    # Every trial weight set was non-negative.  Returning `count` rather than
    # the loop variable also handles empty `x`, where the loop never runs.
    return count

def extend(x):
    """Return an initial guess for an (n+1)-point rule from an n-point rule.

    `x` packs an n-point rule: n abscissae followed by n weights.  The new
    abscissae are the n+1 midpoints between consecutive entries of
    [-1, x_0, ..., x_{n-1}, 1].  The new weights are ``w_(nodes, m)``,
    where ``m = find_n(nodes)`` is the moment count for which `find_n`
    found non-negative weights (the same choice `step` makes).
    Returns a 1-D float array of length 2*(n+1).
    """
    n = len(x)//2
    # Force float: with an integer rule (e.g. the seed [0, 2]) the midpoint
    # division below would otherwise floor-divide under Python 2.
    nodes = np.concatenate(([-1.0], np.asarray(x[:n], dtype=float), [1.0]))
    nodes = (nodes[:-1] + nodes[1:])/2.0
    m = find_n(nodes)
    weights = np.asarray(w_(nodes, m)).ravel()
    return np.concatenate((nodes, weights))

def J(x):
    """Return the (2n, 2n) Jacobian of `f` with respect to the packed `x`.

    With n = len(x)//2 abscissae x_i followed by n weights w_i, `f` has
    components ``f_j = sum_i w_i * x_i**j - m_j`` for j = 0 .. 2n-1.  The
    first n columns are therefore ``d f_j / d x_i = j * x_i**(j-1) * w_i``
    and the last n columns are ``d f_j / d w_i = x_i**j``.
    """
    x = np.asarray(x, dtype=float)
    n = len(x)//2
    j = np.arange(2*n, dtype=float).reshape((2*n, 1))
    nodes = x[:n].reshape((1, n))
    w = x[n:].reshape((1, n))
    # Clamp the exponent at 0.  Row j = 0 is multiplied by j = 0 anyway, and
    # clamping avoids evaluating 0**-1 (divide-by-zero / nan) when a node is
    # exactly at the origin.
    d_nodes = j*nodes**np.maximum(j - 1, 0)*w
    d_weights = nodes**j
    return np.hstack([d_nodes, d_weights])

from mmf.solve.broyden import Broyden

def step(B):
    """Project the solver's current iterate onto valid rules and update `B`.

    Copies ``B.x`` (abscissae then weights) and clips the abscissae into
    [-1, 1].  If any weight is negative, it replaces all weights with
    ``w_(nodes, find_n(nodes))``.  It then calls ``B.update(f(x), x)`` on
    the adjusted point.
    NOTE(review): what `update` does next is defined by
    mmf.solve.broyden.Broyden, which is not visible here.
    """
    # Floor division: `len(B.x)/2` is a float under Python 3 and cannot slice.
    n = len(B.x)//2
    x = np.array(B.x)
    x[:n] = np.clip(x[:n], -1, 1)
    if np.min(x[n:]) < 0:
        m = find_n(x[:n])
        x[n:] = w_(x[:n], m).flat
    B.update(f(x), x)

def init(x, h=1e-6):
    """Return a Broyden solver for `f` started at `x`, with a primed Jacobian.

    The solver is created with ``max_step=0.01``.  Then, for each
    coordinate i, the points x + h*e_i and x - h*e_i are passed to
    ``B.update_J`` together with their residuals `f`.  `h` is the size of
    that perturbation (default 1e-6, the previously hard-coded value).
    """
    # Float copy: `x[i] += h` on an integer array would silently do nothing.
    x = np.array(x, dtype=float)
    B = Broyden(x0=np.array(x), G0=f(x), max_step=0.01)
    for i in range(len(B.x)):
        xi = x[i]
        x[i] = xi + h
        B.update_J(np.array(x), f(x))
        x[i] = xi - h
        B.update_J(np.array(x), f(x))
        # Restore exactly; `+h, -2h, +h` can leave a rounding residue.
        x[i] = xi
    return B

def get(N, tol=1e-12):
    """Return the `N` point quadrature as ``(abscissae, weights)``.

    Starts from the 1-point rule (node 0, weight 2).  Each larger rule is
    seeded with `extend` and refined with the Broyden iteration
    (`init` / `step`) until ``max(abs(f)) <= tol``.  The iteration also
    stops if the residual becomes NaN.  Progress is printed on every
    iteration.
    """
    # Float seed: an integer seed makes `extend` floor-divide under Python 2.
    x = np.array([0.0, 2.0])
    for n in range(1, N):
        x = extend(x)
        B = init(x)
        while True:
            err = np.max(abs(f(B.x)))
            if not tol < err:  # negated comparison also stops on NaN
                break
            # B.x currently holds n+1 abscissae followed by n+1 weights.
            # A single formatted argument prints the same under Python 2 and 3.
            print("%s %s %s %s" % (n + 1, err, B.x[:n + 1], B.x[n + 1:]))
            step(B)
        x = B.x
    return x[:N], x[N:]

def getJ(N, tol=1e-12):
    """Return the `N` point quadrature as ``(abscissae, weights)``.

    Like `get`, but each extended rule is refined with full Newton steps,
    ``x -= solve(J(x), f(x))``, using the analytic Jacobian `J`.  It
    iterates until ``max(abs(f)) <= tol``, and also stops if the residual
    becomes NaN.  Progress is printed on every iteration.
    """
    # Float seed: an integer seed makes `extend` floor-divide under Python 2.
    x = np.array([0.0, 2.0])
    for n in range(1, N):
        x = extend(x)
        while True:
            res = f(x)  # evaluate once; reused for the test, print and step
            err = np.max(abs(res))
            if not tol < err:  # negated comparison also stops on NaN
                break
            # x currently holds n+1 abscissae followed by n+1 weights.
            print("%s %s %s %s" % (n + 1, err, x[:n + 1], x[n + 1:]))
            x = x - np.linalg.solve(J(x), res)
    return x[:N], x[N:]

# Example starting point for an n-point rule: n equally spaced interior
# abscissae with equal weights 2/n.  These weights sum to 2, the integral of
# 1 over [-1, 1], so the zeroth-moment residual f(x0)[0] is already zero.
n = 5
x0 = np.hstack([np.linspace(-1, 1, n + 2)[1:-1],
                np.ones(n)*2.0/n])
B = Broyden(x0=x0, G0=f(x0))