#!/usr/bin/env python3

"""very simple gnutella path folding simulator."""

from random import random, choice, randint, shuffle
import numpy as np
import bisect
from copy import deepcopy
import math

def dist(loc0, loc1):
    """Return the wraparound distance of two locations in the [0,1) space.

    The key space is treated as a circle of circumference 1, so the result
    is the shorter of the two arcs between ``loc0`` and ``loc1``.

    :param loc0: first location.
    :param loc1: second location.
    :return: the shorter arc length between both locations.
    """
    forward = (loc0 - loc1) % 1
    backward = (loc1 - loc0) % 1
    return backward if backward < forward else forward

def distances(target, nodelist):
    """Return the wraparound distance from every listed node to the target.

    :param target: location the distances are measured to.
    :param nodelist: iterable of node locations.
    :return: list of distances, in the same order as ``nodelist``.
    """
    return [dist(node, target) for node in nodelist]

def randomnodes(num=500):
    """Generate ``num`` distinct node locations drawn from ``random()``.

    :param num: number of locations to create.
    :return: the locations as a list sorted in ascending order.
    """
    locations = set()
    for _ in range(num):
        candidate = random()
        # redraw on the (unlikely) collision so every location is unique
        while candidate in locations:
            candidate = random()
        locations.add(candidate)
    return sorted(locations)

def generateflat(nodes, connectionspernode=20):
    """Generate a randomly connected network.

    Each node in ``nodes`` draws random peers until it holds
    ``connectionspernode`` links. Links are stored in both directions, so a
    node can receive further links from nodes processed later.

    Note that a peer is redrawn until a suitable one is found; if no
    suitable peer exists the search does not terminate.

    :param nodes: sequence of node locations.
    :param connectionspernode: number of links each node builds up to.
    :return: dict mapping every location to the sorted list of its peers.
    """
    net = {}
    for node in nodes:
        if node not in net:
            net[node] = []
        # nodes which already received links from others need fewer new ones
        for _ in range(connectionspernode - len(net[node])):
            # find connections to others we do not know yet
            conn = choice(nodes)
            # Reject ourselves, peers we are already linked with, and peers
            # which already hold more than connectionspernode + 2 links.
            while (conn == node or
                   conn in net and (
                       conn in net[node] or
                       node in net[conn] or
                       net[conn][connectionspernode+2:])):
                conn = choice(nodes)
            net[node].append(conn)
            if conn not in net:
                net[conn] = []
            net[conn].append(node)
    # Sort in place. The previous ``net[node] == sorted(conns)`` was a
    # comparison whose result was discarded, so nothing was ever sorted.
    for conns in net.values():
        conns.sort()
    return net

def closetolocrandomdist(loc, nodes, lennodes):
    """Look up the index of a node near ``loc`` with a windowed bisect.

    With the assumption that the locations follow a random flat
    distribution, the node for ``loc`` is expected around index
    ``loc * lennodes``. Only a window of about ten slots to either side of
    that index is searched, which trades some accuracy for speed.

    :param loc: location to look up.
    :param nodes: sorted list of node locations.
    :param lennodes: number of entries in ``nodes``.
    :return: the bisect insertion point of ``loc`` inside the window, minus
        one. This is -1 when the insertion point is 0.
    """
    margin = 10/lennodes
    windowstart = max(0, int((loc - margin) * lennodes))
    windowend = min(lennodes, int((loc + margin) * lennodes))
    insertat = bisect.bisect_left(nodes, loc, windowstart, windowend)
    return insertat - 1

def smallworldtargetlocations(node, connectionspernode):
    """Pick random target locations for small world links of ``node``.

    Small world routing requires that the number of connections to a
    certain area in the key space scales with 1/distance. Candidate
    locations are therefore placed at the offsets 0.5/1, 0.5/2, 0.5/3, ...
    on both sides of ``node``, and a random selection of them is returned.

    :param node: location of the node which needs links.
    :param connectionspernode: number of target locations to return.
    :return: list of ``connectionspernode`` locations in [0,1).
    """
    maxdist = 0.5
    candidates = []
    for step in range(1, int(connectionspernode*100) + 1):
        offset = maxdist / step
        candidates.extend(((node + offset) % 1, (node - offset) % 1))
    shuffle(candidates)
    return candidates[:connectionspernode]

def generatesmallworldunclean(nodes, connectionspernode=20):
    """Generate a small world network with about ``connectionspernode`` links.

    For every node, target locations are drawn with
    ``smallworldtargetlocations`` and the node found for each of them by
    ``closetolocrandomdist`` becomes a peer. Links are stored in both
    directions. When the found node is unsuitable, the search hops a random
    number of slots through ``nodes`` until a suitable one turns up; this
    does not terminate if no suitable node exists.

    :param nodes: sorted list of node locations.
    :param connectionspernode: number of links each node builds up to.
    :return: dict mapping every location to the sorted list of its peers.
    """
    net = {}
    lennodes = len(nodes)
    for node in nodes:
        if node not in net:
            net[node] = []
        # nodes which already received links from others need fewer new ones
        numconn = connectionspernode - len(net[node])
        for loc in smallworldtargetlocations(node, numconn):
            # find connections to others we do not know yet
            closestindex = closetolocrandomdist(loc, nodes, lennodes)
            conn = nodes[closestindex]
            # Reject ourselves, peers we are already linked with, and peers
            # which already hold more than connectionspernode + 10 links.
            while (conn == node or
                   conn in net and (
                       conn in net[node] or
                       node in net[conn] or
                       net[conn][connectionspernode+10:])):
                # randint() requires integer bounds: true division produced
                # floats, which fail for non-integral values and are
                # rejected entirely by newer Python versions. Keep at least
                # one slot so the range never becomes empty.
                hop = max(1, lennodes // connectionspernode)
                closestindex = (closestindex + randint(-hop, hop - 1)) % lennodes
                conn = nodes[closestindex]
            net[node].append(conn)
            if conn not in net:
                net[conn] = []
            net[conn].append(node)

    # we need to make sure that all connections are sorted
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """Return the node which is closest to the target.

    Every entry is inspected. Stopping at the first entry that does not
    improve on its predecessor is not safe here: because distances wrap
    around, the nearest node of a sorted list can sit at its far end
    (target 0.95 with nodes [0.1, 0.5, 0.9]), and lists extended by plain
    appends are not sorted at all.

    :param target: location to get close to.
    :param sortedlist: non-empty list of node locations; the order only
        decides ties, where the earliest entry wins.
    :return: the entry with the smallest wraparound distance to ``target``.
    :raises ValueError: if ``sortedlist`` is empty.
    """
    if not sortedlist:
        raise ValueError("Cannot find the closest node in an emtpy list.")
    return min(sortedlist, key=lambda node: dist(node, target))

def findroute(target, start, net, maxhtl=9999):
    """Find a route to the target using greedy routing.

    Each hop goes to the peer ``closest`` picks for the target. If that peer
    was already visited, a random unvisited peer is taken instead, so a
    route is never repeated deterministically.

    :param target: location the route should end at.
    :param start: location the route begins at.
    :param net: dict mapping every location to the list of its peers.
    :param maxhtl: number of routing steps to try; the target has to be
        reached in fewer hops than this.
    :return: the visited locations after ``start``, ending with ``target``.
        An empty list is returned when ``start`` is the target, when the
        route runs into a dead end, or when the steps are used up.
    """
    current = start
    visited = set()
    hops = []
    for _ in range(maxhtl):
        if current == target:
            return hops
        here = current
        current = closest(target, net[here])
        if current in visited:
            untried = [peer for peer in net[here] if peer not in visited]
            if not untried:
                # can't reach the target from here: all routes already taken
                return []
            current = choice(untried)
        hops.append(current)
        visited.add(current)
    return []

def numberofconnections(net):
    """Get the number of connections for each node.

    :param net: dict mapping every location to the list of its peers.
    :return: list with the length of each peer list, in dict order.
    """
    return [len(peers) for peers in net.values()]

def smallworldbindef(numbins):
    """Create the upper boundaries of the small world bins.

    The boundaries grow by a factor of ten per bin and end at 0.5; for four
    bins they are [0.0005, 0.005, 0.05, 0.5].

    :param numbins: number of bins.
    :return: ascending list of ``numbins`` upper boundaries.
    """
    return [0.5 / 10**exponent for exponent in reversed(range(numbins))]

def deviationfromsmallworld(linkdistances, numbins=5):
    """Calculate how far link lengths deviate from a small world distribution.

    The link lengths are sorted into the logarithmic bins defined by
    ``smallworldbindef`` and the share of links in every bin is compared
    with the ideal share of ``1/numbins``. A length that is not below the
    largest boundary is not counted in any bin.

    numbins should be clearly less than the number of connections.

    :param linkdistances: list of link lengths.
    :param numbins: number of logarithmic bins.
    :return: list with the absolute difference between real and ideal share
        for each bin.
    """
    uppers = smallworldbindef(numbins)
    total = len(linkdistances)
    ideal = [1/numbins for _ in range(numbins)]
    observed = [0] * numbins
    for length in linkdistances:
        # the first bin whose upper boundary lies above the length takes it
        slot = next(
            (idx for idx, upper in enumerate(uppers) if length < upper),
            None)
        if slot is not None:
            observed[slot] += 1./total
    # the difference is the cost
    return [abs(real - wanted) for real, wanted in zip(observed, ideal)]

def jumptotarget(target, prev, net):
    """Move the node at ``prev`` to the target location.

    The node keeps its peers: its entry is re-keyed in ``net`` and every
    peer list that named ``prev`` names the new location afterwards.

    :param target: location to move to.
    :param prev: current location of the node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    """
    # Do not land exactly on an occupied location (or on prev itself) to
    # avoid duplicate entries: nudge by a tiny random offset until free.
    newloc = target
    while newloc in net or newloc == prev:
        newloc = (newloc + (random()-0.5)*1.e-9) % 1
    peers = net[prev][:]
    net[newloc] = peers
    del net[prev]
    for peer in peers:
        rewired = [c for c in net[peer] if c != prev]
        rewired.append(newloc)
        rewired.sort()
        net[peer] = rewired

def switchwithtarget(target, prev, net):
    """Switch the locations of the nodes at ``prev`` and ``target``.

    Both nodes keep their peers; only the locations they are known under
    are exchanged. The renaming is applied to all affected entries in one
    step. Rewriting the peer lists in two passes (first prev -> target,
    then target -> prev) dropped one link for every common peer and created
    self-links when ``prev`` and ``target`` were connected to each other.

    :param target: location of the first node.
    :param prev: location of the second node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    """
    relabel = {prev: target, target: prev}
    affected = {prev, target}
    affected.update(net[prev])
    affected.update(net[target])
    # Build all new entries from the untouched network before writing any
    # of them back, so no update sees a half-renamed state.
    rewired = {
        relabel.get(node, node): sorted(relabel.get(peer, peer)
                                        for peer in net[node])
        for node in affected
    }
    net.update(rewired)

def connecttotarget(target, prev, net):
    """Connect prev to the target.

    The link is entered in both peer lists at its sorted position, as the
    other functions which rewrite peer lists keep them sorted. Nothing is
    added for a link that already exists or that would connect a node with
    itself, so peer lists do not collect duplicate entries.

    :param target: location of the node to connect to.
    :param prev: location of the connecting node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    """
    if target == prev:
        return
    if target not in net[prev]:
        bisect.insort(net[prev], target)
    if prev not in net[target]:
        bisect.insort(net[target], prev)

def disconnectfromtarget(target, prev, net):
    """Disconnect prev from the target.

    One entry is removed from each of the two peer lists, first from the
    list of ``prev`` and then from the list of ``target``.

    :param target: location of the node to disconnect from.
    :param prev: location of the disconnecting node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    :raises ValueError: if one of the lists does not contain the other node.
    """
    for owner, peer in ((prev, target), (target, prev)):
        net[owner].remove(peer)
    
def hasbetterlinks(target, prev, net):
    """Check if the peers of prev fit the target position better.

    The link lengths of the peers of ``prev`` are measured once from
    ``prev`` and once from ``target``; the position whose summed deviation
    from a small world distribution is lower wins.

    :param target: location to evaluate.
    :param prev: location whose peers are used.
    :param net: dict mapping every location to the list of its peers.
    :return: True if the target position deviates less, otherwise None.
    """
    peers = net[prev][:]
    currentcost = sum(deviationfromsmallworld(distances(prev, peers)))
    targetcost = sum(deviationfromsmallworld(distances(target, peers)))
    if targetcost < currentcost:
        return True
    return None

def checkjump(target, prev, net):
    """Jump to the target if the links of prev fit better there.

    :param target: location to jump to.
    :param prev: current location of the node.
    :param net: dict mapping every location to the list of its peers;
        modified in place when the jump happens.
    :return: True if the node jumped, otherwise None.
    """
    if not hasbetterlinks(target, prev, net):
        return None
    jumptotarget(target, prev, net)
    return True

def checkswitch(target, prev, net):
    """Switch locations with the target if both nodes gain from it.

    The switch only happens when the peers of ``prev`` fit the target
    position better and the peers of ``target`` fit the position of
    ``prev`` better.

    :param target: location of the node to switch with.
    :param prev: current location of the node.
    :param net: dict mapping every location to the list of its peers;
        modified in place when the switch happens.
    :return: True if the locations were switched, otherwise None.
    """
    bothgain = (hasbetterlinks(target, prev, net)
                and hasbetterlinks(prev, target, net))
    if not bothgain:
        return None
    switchwithtarget(target, prev, net)
    return True

def checkjumphalf(target, prev, net):
    """Check if we should jump halfway to the target.

    The halfway point lies on the shorter arc between ``prev`` and
    ``target``, in line with the wraparound distance used everywhere else.
    The plain average ``(prev + target) / 2`` picks the middle of the
    longer arc whenever the shorter one crosses the 0/1 boundary: 0.9 and
    0.1 would give 0.5 instead of 0.0.

    :param target: location to move towards.
    :param prev: current location of the node.
    :param net: dict mapping every location to the list of its peers;
        modified in place when the jump happens.
    :return: True if the node jumped, otherwise None.
    """
    # signed offset from prev to target along the shorter arc, in [-0.5, 0.5)
    offset = (target - prev + 0.5) % 1 - 0.5
    halfway = (prev + offset / 2) % 1
    if hasbetterlinks(halfway, prev, net):
        jumptotarget(halfway, prev, net)
        return True
    return None

def checkconnect(target, prev, net):
    """Check if we should connect to the target.

    Check: Does it improve our local links if we replace a connection from
    the worst fitting distance bin with the target? If so, connect to the
    target; the replaced peer is only disconnected if it holds more than
    half as many links as prev holds after connecting.

    :param target: location of the node to connect to.
    :param prev: location of the node which may connect.
    :param net: dict mapping every location to the list of its peers;
        modified in place when the connection is made.
    :return: True if prev connected to the target, otherwise None.
    """
    numbins = 5
    neighbours = net[prev][:]
    lengths = distances(prev, neighbours)
    current = deviationfromsmallworld(lengths, numbins=numbins)
    uppers = smallworldbindef(numbins=numbins)
    # the bin which deviates most from the ideal share
    worst = current.index(max(current))
    lower = uppers[worst - 1] if worst else 0
    upper = uppers[worst]
    candidates = [d for d in lengths if lower <= d < upper]
    if not candidates:
        return None  # no really bad connections
    victimdist = choice(candidates)
    proposed = [d for d in lengths if d != victimdist]
    proposed.append(dist(prev, target))
    improved = deviationfromsmallworld(proposed, numbins=numbins)
    if not sum(improved) < sum(current):
        return None
    connecttotarget(target, prev, net)
    victim = neighbours[lengths.index(victimdist)]
    # only drop the victim if it has enough connections left
    if len(net[victim]) > int(len(net[prev]) / 2):
        disconnectfromtarget(victim, prev, net)
    return True

def checkreplacebest(target, prev, net):
    """Connect to the target in exchange for a link from the best bin.

    A random connection is picked from the distance bin which deviates
    least from the small world distribution. If that bin holds connections,
    prev connects to the target; the picked peer is only disconnected if it
    holds more than half as many links as prev holds after connecting.

    :param target: location of the node to connect to.
    :param prev: location of the node which may connect.
    :param net: dict mapping every location to the list of its peers;
        modified in place when the connection is made.
    :return: True if prev connected to the target, otherwise None.
    """
    numbins = 5
    neighbours = net[prev][:]
    lengths = distances(prev, neighbours)
    deviation = deviationfromsmallworld(lengths, numbins=numbins)
    uppers = smallworldbindef(numbins=numbins)
    # the bin which deviates least from the ideal share
    best = deviation.index(min(deviation))
    lower = uppers[best - 1] if best else 0
    upper = uppers[best]
    candidates = [d for d in lengths if lower <= d < upper]
    if not candidates:
        return None  # no really good connections
    victimdist = choice(candidates)
    connecttotarget(target, prev, net)
    victim = neighbours[lengths.index(victimdist)]
    # only drop the victim if it has enough connections left
    if len(net[victim]) > int(len(net[prev]) / 2):
        disconnectfromtarget(victim, prev, net)
    return True
        
def checksimpleconnect(target, prev, net):
    """Just connect to the target and disconnect a random peer.

    The random peer is taken from the connections prev had before
    connecting. It is only disconnected if it holds more than half as many
    links as prev holds after connecting.

    :param target: location of the node to connect to.
    :param prev: location of the connecting node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    :return: always True.
    """
    before = net[prev][:]
    connecttotarget(target, prev, net)
    victim = choice(before)
    threshold = int(len(net[prev]) / 2)
    if len(net[victim]) > threshold:
        disconnectfromtarget(victim, prev, net)
    return True

def checkreplacelongest(target, prev, net):
    """Just connect to the target and disconnect the longest link.

    The longest link is searched among the connections prev had before
    connecting. Its peer is only disconnected if it holds more than half as
    many links as prev holds after connecting.

    :param target: location of the node to connect to.
    :param prev: location of the connecting node.
    :param net: dict mapping every location to the list of its peers;
        modified in place.
    :return: always True.
    """
    before = net[prev][:]
    connecttotarget(target, prev, net)
    lengths = distances(prev, before)
    victim = before[lengths.index(max(lengths))]
    threshold = int(len(net[prev]) / 2)
    if len(net[victim]) > threshold:
        disconnectfromtarget(victim, prev, net)
    return True

def checkfold(target, prev, net, probability=0.2, strategy="connect"):
    """Apply the chosen folding strategy with some probability.

    :param target: location to fold towards.
    :param prev: location of the node which may fold.
    :param net: dict mapping every location to the list of its peers; the
        strategies modify it in place.
    :param probability: a fold is only attempted when a fresh random number
        does not exceed this value.
    :param strategy: one of "jump", "switch", "jumphalf", "connect",
        "replacebest", "replacelongest" or "connectsimple".
    :return: the result of the strategy function; None if no fold was
        attempted or the strategy is unknown.
    """
    if random() > probability:
        return None
    if strategy == "jump":
        handler = checkjump
    elif strategy == "switch":
        handler = checkswitch
    elif strategy == "jumphalf":
        handler = checkjumphalf
    elif strategy == "connect":
        handler = checkconnect
    elif strategy == "replacebest":
        handler = checkreplacebest
    elif strategy == "replacelongest":
        handler = checkreplacelongest
    elif strategy == "connectsimple":
        handler = checksimpleconnect
    else:
        return None
    return handler(target, prev, net)

def fold(net, num=100, maxhtl=20):
    """Do num path foldings.

    For every folding a random start and a different random target are
    drawn and a route is searched. The length of each found route is
    recorded; the nodes on it only get the chance to fold if the route is
    not longer than ``maxhtl``.

    :param net: dict mapping every location to the list of its peers;
        modified in place by the foldings.
    :param num: number of routes to try.
    :param maxhtl: longest route which still triggers folding.
    :return: the lengths of all used routes. The list is empty if the
        network has fewer than two nodes, as no route can be drawn then.
    """
    routelengths = []
    for _ in range(num):
        # the keys can change between foldings, so list them every time
        nodes = list(net)
        if len(nodes) < 2:
            # Without a second node no distinct target exists: drawing one
            # would loop forever (one node) or fail (no node).
            break
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)
        route = findroute(target, start, net)
        if not route:
            continue
        routelen = len(route)
        routelengths.append(routelen)
        if routelen > maxhtl:
            continue
        # fold all on the route except for the endpoint
        for prev in route[:-1]:
            if checkfold(target, prev, net):
                # the nodes further along the route fold towards the node
                # which just folded
                target = prev
    return routelengths

def linklengths(net):
    """Calculate the lengths of all links.

    Every entry of every peer list is measured from its own node, so a link
    which is stored at both of its ends is counted twice.

    :param net: dict mapping every location to the list of its peers.
    :return: flat list of link lengths.
    """
    return [length
            for node, peers in net.items()
            for length in distances(node, peers)]

if __name__ == "__main__":
    def report(linklens, routelengths):
        """Print one line of statistics about links and routes.

        Columns: mean link length, mean and standard deviation of the route
        lengths, share of routes shorter than 20 hops, shortest and longest
        route, and the summed deviation from a small world distribution.
        """
        print(np.mean(linklens), np.mean(routelengths), "±", np.std(routelengths),
              "succ", len([r for r in routelengths if r < 20])/len(routelengths),
              min(routelengths), max(routelengths),
              sum(deviationfromsmallworld(linklens, numbins=10)))

    nodes = randomnodes()
    basenet = generatesmallworldunclean(nodes)
    # (run, step) -> (link lengths, route lengths)
    lensnapshots = {}
    foldperstep = 500
    # run 0 starts from the small world network, run 1 from a flat random one
    for run in range(2):
        if not run:
            net = deepcopy(basenet)
        else:
            net = generateflat(nodes)
        lensnapshots[(run, 0)] = linklengths(net), [0]
        # the mean link length of the first run serves as the reference
        print(np.mean(lensnapshots[(0, 0)][0]))
        print("===", "run", run, "===")
        # a short first round of 20 foldings
        routelengths = fold(net, 20)
        linklens = linklengths(net)
        report(linklens, routelengths)
        for i in range(40):
            routelengths = fold(net, foldperstep)
            linklens = linklengths(net)
            lensnapshots[(run, i+1)] = linklens, routelengths
            report(linklens, routelengths)

    # now plot the data
    import pylab as pl
    for key, val in sorted(lensnapshots.items()):
        run, i = key
        linklen, routelen = val
        # only plot one in 10 results
        if i % 10:
            continue
        # ``density`` replaces the ``normed`` keyword, which Matplotlib
        # no longer accepts.
        pl.hist(linklen, 10000, cumulative=True, density=True, histtype='step', label=str(run) + ", " + str(i*foldperstep) + ", " + str(np.mean(routelen)))
        pl.semilogx()
    pl.legend(loc="best")
    pl.savefig("123.png")