#!/usr/bin/env python3

"""very simple gnutella path folding simulator."""

from random import random, choice, randint, shuffle
import numpy as np
import pylab as pl
import bisect
from copy import deepcopy
import math

def dist(loc0, loc1):
    """return the wraparound distance of 2 nodes in the [0,1) space."""
    return min((loc0-loc1)%1, (loc1-loc0)%1)
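# Illustrative example (assumed values): the wraparound metric treats 0.95 and
# 0.05 as only about 0.1 apart, not 0.9:
#   dist(0.95, 0.05) ≈ 0.1
#   dist(0.2, 0.6)   ≈ 0.4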

def distances(target, nodelist):
    """calculate the distances between the target and the given list of nodes."""
    lengths = []
    for n in nodelist:
        lengths.append(dist(n,target))
    return lengths

def randomnodes(num=20000):
    """Generate num nodes with locations between 0 and 1."""
    nodes = set()
    for _ in range(num):
        loc = random()
        while loc in nodes:
            loc = random()
        nodes.add(loc)
    return sorted(nodes)

def generateflat(nodes, connectionspernode=20):
    """generate a randomly connected network with the given connectionspernode."""
    net = {}
    for node in nodes:
        if not node in net:
            net[node] = []
        for n in range(connectionspernode - len(net[node])):
            # find connections to others we do not know yet
            conn = choice(nodes)
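            # reject candidates that are the node itself, already linked in
            # either direction, or already holding more than
            # connectionspernode links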
            while (conn == node or
                   conn in net and (
                    conn in net[node] or
                    node in net[conn] or
                   net[conn][connectionspernode:])):
                conn = choice(nodes)
            net[node].append(conn)
            if not conn in net:
                net[conn] = []
            net[conn].append(node)
    # keep every connection list sorted (closest() relies on that order)
    for node, conns in net.items():
        net[node] = sorted(conns)
    return net

def closetolocrandomdist(loc, nodes, lennodes):
    """With the assumption that our locations are in a random flat
    distribution, we can do a faster lookup of the location to
    take.

    Essentially this does automatic binning at the cost of some accuracy.
    """
    lo = max(0, int((loc - 10/lennodes)  * lennodes))
    hi = min(lennodes, int((loc + 10/lennodes)  * lennodes))
    return bisect.bisect_left(nodes, loc, lo=lo, hi=hi) - 1
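# Illustrative example (assumed values): with 20000 uniformly distributed
# node locations, a location around 0.42 is expected near sorted index
# 0.42 * 20000 = 8400, so the bisect search only scans a window of roughly
# 20 indices around that rank instead of the whole list.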

def smallworldtargetlocations(node, connectionspernode):
    """Small world routing requires that the number number of
    connections to a certain area in the key space scales with
    1/distance."""
    maxdist = 0.5
    dists = [maxdist / (i+1) for i in range(int(connectionspernode*100))]
    locs = []
    for d in dists:
        locs.append((node + d) % 1)
        locs.append((node - d) % 1)
    shuffle(locs)
    return locs[:connectionspernode]
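# Illustrative note: the candidate distances form the harmonic sequence
# 0.5, 0.25, 0.5/3, 0.125, ..., so most of the shuffled candidate locations
# lie very close to the node; sampling connectionspernode of them therefore
# yields mostly short links plus a few long-range ones.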

def generatesmallworldunclean(nodes, connectionspernode=20):
    """generate an ideal small world network with the given connectionspernode."""

    net = {}
    lennodes = len(nodes)
    for node in nodes:
        if not node in net:
            net[node] = []
        numconn = connectionspernode - len(net[node])
        for loc in smallworldtargetlocations(node, numconn):
            # find connections to others we do not know yet
            closestindex = closetolocrandomdist(loc, nodes, lennodes)
            conn = nodes[closestindex]
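            # if the candidate is unusable (the node itself, an existing link
            # in either direction, or a peer already far over the connection
            # limit), retry with a random nearby index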
            while (conn == node or
                   conn in net and (
                    conn in net[node] or
                    node in net[conn] or
                   net[conn][connectionspernode+10:])):
                closestindex = (closestindex + randint(-(lennodes//connectionspernode), lennodes//connectionspernode - 1)) % len(nodes)
                conn = nodes[closestindex]
            net[node].append(conn)
            if not conn in net:
                net[conn] = []
            net[conn].append(node)

    # we need to make sure that all connections are sorted…
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """return the node which is closest to the target."""
    if not sortedlist:
        raise ValueError("Cannot find the closest node in an empty list.")
    dis = 1.1
    for n in sortedlist:
        d = dist(n, target)
        if d < dis:
            best = n
            dis = d
        else:
            # early exit as soon as the distances start growing again; since
            # the distance wraps around at 0/1 this can miss a slightly
            # closer node at the other end of the sorted list, so the result
            # is only approximately the closest node.
            return best
    return best

def findroute(target, start, net):
    """Find the best route to the target usinggreedy routing."""
    best = start
    seen = set()
    route = []
    while not best == target:
        prev = best
        best = closest(target, net[prev])
        # never repeat a route deterministically
        if best in seen:
            possible = [node for node in net[prev] if not node in seen]
            if not possible:
                print ("all routes already taken from", start, "to", target, "over", prev)
                return [] # can’t reach the target from here: all routes already taken
            best = choice(possible)
        route.append(best)
        seen.add(best)
    return route
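# Illustrative note: findroute performs plain greedy routing; every hop moves
# to the neighbour closest to the target in the wraparound metric, falling
# back to a random unvisited neighbour when the greedy choice was already
# visited. The returned route ends with the target and does not include the
# start node.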

def numberofconnections(net):
    """Get the number of connections for each node."""
    numconns = []
    for n in net.values():
        numconns.append(len(n))
    return numconns

def dofold(target, prev, net): 
    """Actually do the switch to the target location."""
    # do not switch to exactly the target to avoid duplicate entries
    # (implementation detail)
    realtarget = target
    while realtarget in net or realtarget == prev:
        realtarget = (realtarget + (random()-0.5)*1.e-9) % 1
    connections = net[prev]
    net[realtarget] = connections
    del net[prev]
    for conn in connections:
        net[conn] = sorted([c for c in net[conn] if not c == prev] + [realtarget])
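# Illustrative example (assumed values): if the node at 0.31 folds towards a
# target of 0.8, its entry is re-keyed to a location jittered within about
# ±5e-10 of 0.8, and all of its neighbours' connection lists are rewritten to
# point at the new key.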

def deviationfromsmallworld(distribution, numbins=5):
    """Calculate the deviation of the given local link length
    distribution from a small world distribution.

    Use logarithmic binning to represent the relevant feature correctly.

    numbins should be well below the number of connections per node.
    """
    # first define the bin steps: upper boundary of each bin.
    #: e.g. for numbins=5: [0.00005, 0.0005, 0.005, 0.05, 0.5]
    bindefupper = [0.5/10**((numbins-(i+1))) for i in range(numbins)]
    # then get the total number of links
    numlinks = len(distribution)
    # and define how an ideal bin distribution would look
    idealbins = [1/numbins for i in range(numbins)]
    #idealbins = [i/sum(idealbins) for i in idealbins]
    # now create the real bins
    realbins = [0 for i in range(numbins)]
    for link in distribution:
        for n,upper in enumerate(bindefupper):
            if link < upper:
                realbins[n] += 1./numlinks
                break
    #print ("ideal", idealbins)
    #print ("real", realbins)
    #print ("def", bindefupper)
    # the difference is the cost
    diff = sum(abs(realbins[i] - idealbins[i]) for i in range(numbins)) 
    # add additional cost for highest bin mismatch (they lose that too easily)
    diff += abs(realbins[-1] - idealbins[-1])
    #print ([abs(realbins[i] - idealbins[i]) for i in range(numbins)])
    #print("diff", diff)
    #print()
    return diff
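# Worked example (assumed values): with numbins=5 the upper bin edges are
# [0.00005, 0.0005, 0.005, 0.05, 0.5] and the ideal share per bin is 0.2.
# If all of a node's links have length around 0.001, everything falls into
# the third bin: deviation = |1 - 0.2| + 4 * |0 - 0.2| = 1.6, plus the extra
# 0.2 penalty for the empty top bin, giving 1.8. A perfectly even spread
# across the bins scores 0.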

def checkfold(target, prev, net, probability=0.07):
    """switch to the target location with some probability"""
    if random() > probability:
        return
    conns = net[prev]
    old = distances(prev, conns)
    new = distances(target, conns)
    if deviationfromsmallworld(new) < deviationfromsmallworld(old):
        dofold(target, prev, net)
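# Illustrative note: with the default probability of 0.07 only about one in
# fourteen nodes on a route even attempts a fold, and the fold is carried out
# only if it moves the node's local link-length histogram closer to the ideal
# small-world bins defined above.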

def fold(net, num=100):
    """do num path foldings.

    :return: the lengths of all used routes."""
    routelengths = []
    for i in range(num):
        nodes = list(net.keys())
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)
        route = findroute(target, start, net)
        routelen = len(route)
        routelengths.append(routelen)
        # fold all nodes on the route except for the start and the endpoint
        # (the start node is never part of route, and route[:-1] drops the
        # endpoint)
        for prev in route[:-1]:
            checkfold(target, prev, net)
    return routelengths

def linklengths(net):
    """calculate the lengths of all links"""
    lengths = []
    for node, targets in net.items():
        lengths.extend(distances(node, targets))
    return lengths

if __name__ == "__main__":
    nodes = randomnodes()
    basenet = generatesmallworldunclean(nodes)
    lensnapshots = {}
    foldperstep = 500
    for run in range(2):
        if not run: 
            net = deepcopy(basenet)
        else:
            net = generateflat(nodes)
        lensnapshots[(run,0)] = linklengths(net), [0]
        print(np.mean(lensnapshots[(run, 0)][0]))
        print("===", "run", run, "===")
        for i in range(40):
            lengths = fold(net, foldperstep)
            lensnapshots[(run, i+1)] = linklengths(net), lengths
            print (np.mean(lensnapshots[(run, i+1)][0]), np.mean(lensnapshots[(run, i+1)][1]))

    for key, val in sorted(lensnapshots.items()):
        run, i = key
        linklen, routelen = val
        # only plot one in 10 results
        if i % 10: 
            continue
        pl.hist(linklen, 10000, cumulative=True, density=True, histtype='step',
                label=str(run) + ", " + str(i*foldperstep) + ", " + str(np.mean(routelen)))
        pl.semilogx()
    pl.legend(loc="best")
    pl.savefig("123.png")