~jb753/clusterfunc

1bb6922862579888db999058801d9c8933959c84 — James Brind 11 months ago af1c00b
Vinokur double clustering
M clusterfunc/check.py => clusterfunc/check.py +1 -1
@@ 46,7 46,7 @@ def unit_single(x, Dmin, Dmax, ERmax, rtol=1e-9):
            f"Expansion ratio max(ER)={ER.max()} exceeds target {ERmax}."
        )

    if not (ER >= 1.0 - rtol).all():
    if not (ER >= (1.0 - rtol)).all():
        raise ClusteringException(
            f"Expansion ratio min(ER)={ER.min()} less than unity."
        )

A clusterfunc/double.py => clusterfunc/double.py +238 -0
@@ 0,0 1,238 @@
"""Distribute points with symmetric clustering."""
import numpy as np
from clusterfunc.exceptions import ClusteringException
import clusterfunc.check
import clusterfunc.util
import clusterfunc.single


def _invert_sinhx_x(y):
    """Approximate the x that satisfies y = sinh(x)/x.

    Implements the two-range fit referenced as Eqns. (62-67): a power series
    in (y - 1) scaled by sqrt(6 (y - 1)) below the switch-over value
    y = 2.7829681, and a logarithmic expansion at or above it. Both ranges
    are evaluated for every input and the applicable one is then selected
    element-wise.
    """
    switch_over = 2.7829681

    # Lower range: series in the offset of y from unity
    offset = y - 1.0
    series = 1.0 - 0.15 * offset
    for coeff, power in (
        (0.057321429, 2.0),
        (-0.024907295, 3.0),
        (0.0077424461, 4.0),
        (-0.0010794123, 5.0),
    ):
        series = series + coeff * offset**power
    lower = np.sqrt(6.0 * offset) * series

    # Upper range: expansion in log(y) plus a quartic correction in 1/y
    log_y = np.log(y)
    shifted_inv = 1.0 / y - 0.028527431
    upper = log_y + (1.0 + 1.0 / log_y) * np.log(2.0 * log_y) - 0.02041793
    upper = upper + 0.24902722 * shifted_inv
    for coeff, power in (
        (1.9496443, 2.0),
        (-2.6294547, 3.0),
        (8.56795911, 4.0),
    ):
        upper = upper + coeff * shifted_inv**power

    return np.where(y < switch_over, lower, upper)


def _invert_sinx_x(y):
    """Approximate the x that satisfies y = sin(x)/x.

    Implements the two-range fit referenced as Eqns. (68-71): a polynomial in
    y scaled by pi below the switch-over value y = 0.26938972, and a power
    series in (1 - y) scaled by sqrt(6 (1 - y)) at or above it. Both ranges
    are evaluated for every input and the applicable one is then selected
    element-wise.
    """
    switch_over = 0.26938972

    # Lower range: polynomial in y itself
    poly = 1.0 - y
    poly = poly + y**2.0
    for coeff, power in (
        (-(1.0 + np.pi**2.0 / 6.0), 3.0),
        (6.794732, 4.0),
        (-13.205501, 5.0),
        (11.726095, 6.0),
    ):
        poly = poly + coeff * y**power
    lower = np.pi * poly

    # Upper range: series in the deficit of y below unity
    deficit = 1.0 - y
    series = 1.0 + 0.15 * deficit
    for coeff, power in (
        (0.057321429, 2.0),
        (0.048774238, 3.0),
        (-0.053337753, 4.0),
        (0.075845134, 5.0),
    ):
        series = series + coeff * deficit**power
    upper = np.sqrt(6.0 * deficit) * series

    return np.where(y < switch_over, lower, upper)


def clu(ds, N):
    """Two sided analytic clustering function after Vinokur.

    Parameters
    ----------
    ds : sequence of two floats
        Spacing parameters for the start, ds[0], and end, ds[1], of the
        interval. They set the end slopes of the stretching function and
        must be non-zero with the same sign.
    N : int
        Number of points in the distribution, at least 2.

    Returns
    -------
    t : ndarray
        N values starting at exactly 0 and ending at 1.

    Raises
    ------
    ClusteringException
        If N < 2, or if the stretching parameters derived from `ds` are not
        finite (for example when ds[0] and ds[1] differ in sign).
    """
    if N < 2:
        raise ClusteringException(
            f"Not enough points N={N} to cluster, need N>=2"
        )

    s0 = 1.0 / N / ds[0]
    s1 = N * ds[1]

    A = np.sqrt(s0 * s1)
    B = np.sqrt(s0 / s1)

    # NaN fails every comparison in the branch selection below, so reject
    # unusable parameters explicitly instead of falling through
    if not (np.isfinite(A) and np.isfinite(B)):
        raise ClusteringException(f"Unexpected B={B}, s0={s0}, s1={s1}")

    xi = np.linspace(0.0, 1.0, N)

    if np.abs(B - 1.0) < 0.001:
        # Eqn. (52)
        u = xi * (1.0 + 2.0 * (B - 1.0) * (xi - 0.5) * (1.0 - xi))
    elif B < 1.0:
        # Solve Eqn. (49)
        Dx = _invert_sinx_x(B)
        # Sanity check on the approximate inversion
        assert np.isclose(np.sin(Dx) / Dx, B, rtol=1e-1)
        # Eqn. (50)
        u = 0.5 * (1.0 + np.tan(Dx * (xi - 0.5)) / np.tan(Dx / 2.0))
    else:
        # B is finite and >= 1 here
        # Solve Eqn. (46)
        Dy = _invert_sinhx_x(B)
        # Sanity check on the approximate inversion
        assert np.isclose(np.sinh(Dy) / Dy, B, rtol=1e-1)
        # Eqn. (47)
        u = 0.5 * (1.0 + np.tanh(Dy * (xi - 0.5)) / np.tanh(Dy / 2.0))

    t = u / (A + (1.0 - A) * u)

    # Force to unit interval
    t -= t[0]
    t /= t[-1]

    assert t[0] == 0.0
    assert np.isclose(t[-1], 1.0)

    return t

def clu_free(ds, dmax, ERmax, mult=8):
    """Two sided Vinokur clustering with a free number of points.

    Evaluates `clu` with N = mult*n + 1 points for n = 1, 2, ... and returns
    the first distribution whose spacings are all <= dmax and whose
    `clusterfunc.util.ER` values are all < ERmax.

    Parameters
    ----------
    ds : sequence of two floats
        End spacing parameters, passed through to `clu`.
    dmax : float
        Largest permitted spacing between adjacent points.
    ERmax : float
        Strict upper bound on the values returned by `clusterfunc.util.ER`.
    mult : int
        The number of cells is kept a multiple of this value.

    Returns
    -------
    x : ndarray
        The first distribution found that meets both limits.

    Raises
    ------
    ClusteringException
        If no trial satisfies both limits within the iteration budget.
    """
    maxiter = 50
    for n in range(1, maxiter + 1):
        N = mult * n + 1
        x = clu(ds, N)
        spacing_ok = (np.diff(x) <= dmax).all()
        ratio_ok = (clusterfunc.util.ER(x) < ERmax).all()
        if ratio_ok and spacing_ok:
            return x

    # Do not hand back a grid that violates the requested limits
    raise ClusteringException(
        f"Could not satisfy dmax={dmax}, ERmax={ERmax} "
        f"with up to N={mult * maxiter + 1} points."
    )

# Ad-hoc demonstration of the Vinokur clustering functions above.
# NOTE(review): these statements sit at module level, so they execute --
# including the plt.show() call -- every time this module is imported.
# Presumably leftover development code; confirm before relying on it.
N = 31
ds = (1e-3, 5e-3)
dmax = 0.1
# The fixed-N result is immediately overwritten by the free-N result below
x = clu(ds, N)
x = clu_free(ds, dmax, 1.2)
dx = np.diff(x)
# dx_end = 1./N/np.array([slopes[0], 1./slopes[1]])
ER = clusterfunc.util.ER(x)
import matplotlib.pyplot as plt
import clusterfunc.plot
clusterfunc.plot.plot(x)
clusterfunc.plot.plot_ER(x)
# Plot the spacings and mark the requested end values at the first and last
# cell indices for comparison
fig, ax = plt.subplots()
ax.plot(dx)
ax.plot([0.,len(x)-2], ds, 'k*')
plt.show()

def _unit_fixed_N(Dmin0, Dmin1, Dmax, ERmax, N):
    """Double-sided clustering on the unit interval with N points.

    The grid is assembled from two `clusterfunc.single.fixed_N` segments
    that meet at a join point xj: one of Na points from 0 to xj using
    Dmin0, and one of Nb points using Dmin1 that is mirrored so it ends at
    1. Two nested iterations then adjust the join:

    * the inner loop moves xj until the ratio of the two spacings either
      side of the join is below ERmax;
    * the outer loop moves `frac`, the share of cells given to the first
      segment, until the first and last values of `clusterfunc.util.ER(x)`
      differ by less than 0.01.

    Progress is printed to stdout on every iteration.

    Parameters
    ----------
    Dmin0, Dmin1 : float
        Spacing arguments forwarded to `clusterfunc.single.fixed_N` for the
        segments starting at 0 and at 1 respectively.
    Dmax, ERmax : float
        Forwarded to `clusterfunc.single.fixed_N`; ERmax is also the limit
        on the spacing ratio across the join.
    N : int
        Total number of points, at least 3.

    Returns
    -------
    x : ndarray
        Grid of length N.

    Raises
    ------
    ClusteringException
        If N < 3, if none of the trial join/split combinations can be
        evaluated, or if a segment evaluation raises during iteration.

    NOTE(review): neither loop reports failure when it runs out of
    iterations; the most recent grid is returned as-is. Confirm whether
    that is intended.
    """
    if N < 3:
        raise ClusteringException(
            f"Not enough points N={N} to double cluster, need N>=3"
        )

    def _guess_xj(Dmin0, Dmin1, Dmax, ERmax, Na, Nb, xj):
        """Build the joined grid for join location xj.

        Returns the grid and the spacings immediately before (dxa) and
        after (dxb) the join.
        """
        # Segment lengths either side of the join
        La = xj
        Lb = 1.0 - xj

        # Evaluate the halves
        xa = clusterfunc.single.fixed_N(0.0, La, Dmin0, Dmax, ERmax, Na)
        xb1 = clusterfunc.single.fixed_N(0.0, Lb, Dmin1, Dmax, ERmax, Nb)
        # Mirror the second segment so that it ends at 1
        xb = 1.0 - np.flip(xb1)
        # Drop the first point of xb, which duplicates the join point
        x = np.concatenate((xa, xb[1:]))

        dxa = np.diff(xa)[-1]
        dxb = np.diff(xb)[0]

        return x, dxa, dxb

    # Initial guess of splits
    frac = None
    xj = None

    # Trial values used for both the join location and the cell fraction.
    # NOTE(review): the loops do not stop at the first success, so the LAST
    # combination that evaluates without error is the one kept -- confirm
    # whether the first was intended.
    g = (0.5, 0.4, 0.6, 0.3, 0.7)
    for xjg in g:
        for fracg in g:
            Na, Nb = clusterfunc.util.split_cells(N, fracg)
            try:
                _guess_xj(Dmin0, Dmin1, Dmax, ERmax, Na, Nb, xjg)
                xj = xjg
                frac = fracg
            except ClusteringException:
                continue

    if xj is None:
        raise ClusteringException('Could not find initial guess of join point')


    # Iterate on the number of cells split to equalise ER
    maxiter = 100
    k_frac = 0.1  # Constant of proportionality for changes in frac wrt ER
    print('iterating frac')
    for _ in range(maxiter):

        # Split number of cells
        Na, Nb = clusterfunc.util.split_cells(N, frac)

        print('iterating xj')
        for i in range(maxiter):
            x, dxa, dxb = _guess_xj(Dmin0, Dmin1, Dmax, ERmax, Na, Nb, xj)

            # Ratio of the spacings either side of the join, larger over
            # smaller
            if dxa > dxb:
                ERnow = dxa / dxb
            else:
                ERnow = dxb / dxa

            print(i,':', frac, xj, ERnow)
            if ERnow < ERmax:
                print('xj converged, breaking')
                break
            else:
                # Shift the join by the spacing mismatch
                xj -= dxa - dxb
            print('new xj', xj)

        # Difference between the last and first expansion ratio values
        ER = clusterfunc.util.ER(x)[(0,-1),]
        dER = np.diff(ER)[0]
        if np.abs(dER)<0.01:
            print('ER converged, breaking')
            break

        # Limit the change per step to 0.05 and keep frac within [0.2, 0.8]
        frac -= np.clip(k_frac*dER,-0.05, 0.05)
        frac = np.clip(frac, 0.2, 0.8)

        print('new frac', frac)
    assert len(x) == N

    return x

def _unit_free_N(Dmin0, Dmin1, Dmax, ERmax, mult=8):
    """Double-sided clustering on the unit interval with free number of points.

    Tries point counts N = n*mult + 1 for n = 1, 2, ..., 1000 in increasing
    order and returns the first grid that `_unit_fixed_N` produces without
    raising `ClusteringException`.

    Parameters
    ----------
    Dmin0, Dmin1, Dmax, ERmax : float
        Forwarded unchanged to `_unit_fixed_N`.
    mult : int
        The number of cells is kept a multiple of this value.

    Returns
    -------
    x : ndarray
        The grid from the first successful point count.

    Raises
    ------
    ClusteringException
        If every trial point count fails.
    """
    max_tries = 1000
    for n in range(1, max_tries + 1):
        try:
            return _unit_fixed_N(Dmin0, Dmin1, Dmax, ERmax, n * mult + 1)
        except ClusteringException:
            # This point count is not feasible; try the next larger one
            continue

    raise ClusteringException('Could not bracked valid clustering')

M clusterfunc/plot.py => clusterfunc/plot.py +11 -10
@@ 1,5 1,6 @@
import matplotlib.pyplot as plt
import numpy as np
import clusterfunc.util


def plot(x):


@@ 25,14 26,14 @@ def plot_ER(x):
    ax.set_ylabel("Coordinate")


import clusterfunc.single
import clusterfunc.symmetric
# import clusterfunc.single
# import clusterfunc.symmetric

Dmin = 1e-1
Dmax = 2e-1
ERmax = 1.2
# x = clusterfunc.single.unit_free_N(Dmin, Dmax, ERmax, mult=1)
x = clusterfunc.symmetric.unit_free_N(Dmin, Dmax, ERmax, 1)
plot(x)
# plot_ER(x)
plt.show()
# Dmin = 1e-1
# Dmax = 2e-1
# ERmax = 1.2
# # x = clusterfunc.single.unit_free_N(Dmin, Dmax, ERmax, mult=1)
# x = clusterfunc.symmetric.unit_free_N(Dmin, Dmax, ERmax, 1)
# plot(x)
# # plot_ER(x)
# plt.show()

M clusterfunc/single.py => clusterfunc/single.py +2 -2
@@ 173,7 173,7 @@ def _unit_free_N(Dmin, Dmax, ERmax, mult=8):
    return x


def fixed_N(x0, x1, Dmin, Dmax, ERmax, N):
def fixed_N(x0, x1, Dmin, Dmax, ERmax, N, check=True):
    """Single-sided clustering between two values with with fixed number of points.

    Generate a grid vector x of length N from x0 to x1. Use geometric


@@ 207,7 207,7 @@ def fixed_N(x0, x1, Dmin, Dmax, ERmax, N):
        )
    dx = x1 - x0
    dxa = np.abs(dx)
    x = x0 + dx * _unit_fixed_N(Dmin / dxa, Dmax / dxa, ERmax, N)
    x = x0 + dx * _unit_fixed_N(Dmin / dxa, Dmax / dxa, ERmax, N, check)
    return x



M clusterfunc/util.py => clusterfunc/util.py +2 -2
@@ 1,8 1,8 @@
import numpy as np


def split_cells(N):
    Na = (N - 1) // 2 + 1
def split_cells(N, frac=0.5):
    """Divide N grid points between two segments that share a join point.

    The first segment receives the fraction `frac` of the N - 1 cells
    (truncated to a whole number) plus one point. The second segment takes
    the remainder and counts the shared join point again, so the two
    returned point counts always sum to N + 1.
    """
    cells_first = (N - 1) * frac
    points_first = int(cells_first + 1)
    points_second = N + 1 - points_first
    return points_first, points_second


A tests/test_double.py => tests/test_double.py +49 -0
@@ 0,0 1,49 @@
import clusterfunc.double
import clusterfunc.check
import clusterfunc.util
import numpy as np
import pytest
from clusterfunc.exceptions import ClusteringException


def test_double_unit():
    """Check double-sided clustering over a sweep of input parameters.

    For every combination of end spacings, expansion ratio limit and maximum
    spacing, find the smallest feasible point count with `_unit_free_N`,
    then verify that every point count from there up to twice that value
    (capped at 256) yields a grid that respects the spacing limit, is
    monotonic, and respects the expansion ratio limit.
    """
    rtol = 1e-9

    for Dmin0 in (1e-3, 1e-4):
        for Dmin1_mult in (0.5, 1.0, 2.0):
            Dmin1 = Dmin0 * Dmin1_mult

            for ERmax in (1.1, 1.2):
                for Dmax in (0.05, 0.1, 0.2, 1.0):
                    # Skip combinations where the cap is below an end spacing
                    if Dmax <= Dmin0 or Dmax < Dmin1:
                        continue

                    # Minimum number of points with capping
                    Nmin = len(
                        clusterfunc.double._unit_free_N(
                            Dmin0, Dmin1, Dmax, ERmax, mult=1
                        )
                    )

                    # Upper end of the sweep, limit to not be too huge
                    Nmax = np.minimum(2 * Nmin, 256)
                    clusterfunc.double._unit_fixed_N(
                        Dmin0, Dmin1, Dmax, ERmax, Nmax
                    )

                    for N in range(Nmin, Nmax + 1):
                        x = clusterfunc.double._unit_fixed_N(
                            Dmin0, Dmin1, Dmax, ERmax, N
                        )

                        dx = np.diff(x)
                        ER = clusterfunc.util.ER(x)

                        # No spacing above the cap
                        assert np.all(dx <= Dmax * (1.0 + rtol))
                        # Monotonic: every spacing has the same sign
                        assert (np.diff(np.sign(dx)) == 0.0).all()
                        # Expansion ratio within the limit
                        assert np.all(ER <= ERmax * (1.0 + rtol))


# NOTE(review): this module-level call runs the whole sweep whenever the file
# is imported, so pytest executes it during collection as well as when it
# runs the test itself -- confirm whether it is still wanted.
test_double_unit()

M tests/test_symmetric.py => tests/test_symmetric.py +1 -1
@@ 9,7 9,7 @@ def test_symmetric_unit():
    for Dmin in (1e-1, 1e-2, 1e-3, 1e-4, 1e-5):
        for ER in (1.02, 1.05, 1.1, 1.2):
            # Number of points for uniform, limit to not be too huge
            Nmax = np.minimum(np.floor(1.0 / Dmin).astype(int), 256)
            Nmax = np.minimum(np.floor(1.0 / Dmin).astype(int), 512)

            for Dmax in (0.05, 0.1, 0.2, 1.0):
                if Dmax <= Dmin: