#!/usr/bin/env python3

"""very simple gnutella path folding simulator."""

from random import random, choice, randint, shuffle
import numpy as np
import pylab as pl
import bisect
from copy import deepcopy

def randomnodes(num=1000):
    """Generate num nodes with locations between 0 and 1.

    Draws uniform random locations until ``num`` distinct values have been
    collected (a repeated draw is simply drawn again).

    :param num: number of distinct locations to create.
    :return: the locations as an ascending list.
    """
    locations = set()
    while len(locations) < num:
        # adding an already-present value leaves the set unchanged, so
        # duplicates are transparently re-drawn
        locations.add(random())
    return sorted(locations)

def generateflat(nodes, connectionspernode=10):
    """Generate a randomly connected network with the given connectionspernode.

    Every node links to randomly chosen peers until it has
    ``connectionspernode`` connections. Links are symmetric, so a node may end
    up with more because other nodes connected to it. A candidate peer is
    rejected when it is the node itself, is already linked, or already has
    more than ``connectionspernode`` connections.

    :param nodes: sequence of node locations (peers are drawn with ``choice``).
    :param connectionspernode: target number of connections per node.
    :return: dict mapping each node to the ascending list of its neighbours.
    """
    net = {}
    for node in nodes:
        net.setdefault(node, [])
        for _ in range(connectionspernode - len(net[node])):
            # find connections to others we do not know yet
            conn = choice(nodes)
            while (conn == node or
                   conn in net and (
                       conn in net[node] or
                       node in net[conn] or
                       net[conn][connectionspernode:])):
                conn = choice(nodes)
            net[node].append(conn)
            net.setdefault(conn, []).append(node)
    # closest() stops at the first increase in distance, so every neighbour
    # list must really be sorted in place (a bare `==` would only compare).
    for conns in net.values():
        conns.sort()
    return net

def closetolocrandomdist(loc, nodes, lennodes):
    """Fast index lookup assuming uniformly distributed node locations.

    Because the locations are expected to be spread evenly over [0, 1), the
    position of ``loc`` in the sorted list is roughly ``loc * lennodes``; the
    binary search is therefore limited to a window of about 10 entries on
    each side of that estimate. This trades exactness for speed.

    :param loc: location to look up.
    :param nodes: ascending list of node locations.
    :param lennodes: ``len(nodes)``, passed in to avoid recomputing it.
    :return: index of the node just below ``loc`` within the window; may be
        ``-1`` (i.e. the last node when used as an index).
    """
    window = 10/lennodes
    lower = max(0, int((loc - window) * lennodes))
    upper = min(lennodes, int((loc + window) * lennodes))
    insertion = bisect.bisect_left(nodes, loc, lower, upper)
    return insertion - 1

def smallworldtargetlocations(node, connectionspernode):
    """Pick target locations for small-world links around ``node``.

    Small world routing requires the number of connections into an area of
    the key space to scale with 1/distance. Candidate offsets are
    ``0.5 / k`` for ``k = 1 .. 100 * connectionspernode``, taken in both
    directions around the unit ring. A random ``connectionspernode`` of them
    are returned.

    :param node: location to build links from.
    :param connectionspernode: number of target locations wanted.
    :return: list of up to ``connectionspernode`` locations in [0, 1).
    """
    halfring = 0.5
    candidates = []
    for step in range(1, int(connectionspernode*100) + 1):
        offset = halfring / step
        candidates.extend(((node + offset) % 1, (node - offset) % 1))
    shuffle(candidates)
    return candidates[:connectionspernode]

def generatesmallworldunclean(nodes, connectionspernode=10):
    """Generate an ideal small world network with the given connectionspernode.

    For each node, target locations come from smallworldtargetlocations().
    The node links to the node found near each target. When that peer is
    unusable (itself, already linked, or it has more than
    ``connectionspernode + 10`` links), a random nearby index is tried
    instead.

    :param nodes: ascending list of node locations.
    :param connectionspernode: target number of connections per node.
    :return: dict mapping each node to the ascending list of its neighbours.
    """
    net = {}
    lennodes = len(nodes)
    # Width (in list indices) of the random jump used when the preferred peer
    # is rejected. Integer division because randint() needs ints (floats raise
    # TypeError on Python 3.12+); at least 1 so randint's range is never
    # empty, and guarded so connectionspernode <= 0 cannot divide by zero.
    jump = max(1, lennodes // max(1, connectionspernode))
    for node in nodes:
        net.setdefault(node, [])
        numconn = connectionspernode - len(net[node])
        for loc in smallworldtargetlocations(node, numconn):
            # find connections to others we do not know yet
            closestindex = closetolocrandomdist(loc, nodes, lennodes)
            conn = nodes[closestindex]
            while (conn == node or
                   conn in net and (
                       conn in net[node] or
                       node in net[conn] or
                       net[conn][connectionspernode+10:])):
                closestindex = (closestindex + randint(-jump, jump - 1)) % lennodes
                conn = nodes[closestindex]
            net[node].append(conn)
            net.setdefault(conn, []).append(node)

    # we need to make sure that all connections are sorted (closest() relies on it)
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """Return the entry of ``sortedlist`` nearest to ``target``.

    Relies on ``sortedlist`` being ascending: the distance to ``target``
    shrinks until the nearest entry and grows afterwards, so the scan stops at
    the first entry that is not strictly closer. On ties the earlier entry
    wins. Distances are plain ``abs`` differences (no wrap-around).

    :raises ValueError: if ``sortedlist`` is empty.
    """
    if not sortedlist:
        raise ValueError("Cannot find the closest node in an emtpy list.")
    bestdist = 1.1
    for candidate in sortedlist:
        distance = abs(candidate - target)
        if not distance < bestdist:
            # past the minimum; note `best` is still unbound if even the
            # first entry was not within 1.1 of the target
            break
        best, bestdist = candidate, distance
    return best

def findroute(target, start, net):
    """Find a route from ``start`` to ``target`` using greedy routing.

    Each hop goes to the neighbour closest to ``target``. If that neighbour
    was already visited, a random unvisited neighbour is taken instead, so the
    walk never repeats a hop deterministically.

    :return: list of hops after ``start`` (ending with ``target``), or ``[]``
        when every neighbour of the current node has already been visited.
    """
    visited = set()
    hops = []
    current = start
    while current != target:
        here = current
        current = closest(target, net[here])
        if current in visited:
            unvisited = [peer for peer in net[here] if peer not in visited]
            if not unvisited:
                # dead end: every way out of `here` was already used
                print ("all routes already taken from", start, "to", target, "over", here)
                return []
            current = choice(unvisited)
        hops.append(current)
        visited.add(current)
    return hops

def numberofconnections(net):
    """Return the connection count of every node, in ``net``'s iteration order."""
    return [len(conns) for conns in net.values()]

def dofold(target, prev, net):
    """Move node ``prev`` to (approximately) location ``target`` in ``net``.

    The node keeps its neighbour list under the new key. Each neighbour's
    list is rebuilt with ``prev`` replaced by the new location and re-sorted.
    If ``target`` is already a node (or equals ``prev``), it is jittered by
    up to ±5e-10 until it is unused, because locations serve as dict keys.
    """
    newloc = target
    while newloc in net or newloc == prev:
        newloc = (newloc + (random()-0.5)*1.e-9) % 1
    neighbours = net.pop(prev)
    net[newloc] = neighbours
    for peer in neighbours:
        renamed = [c for c in net[peer] if c != prev]
        renamed.append(newloc)
        net[peer] = sorted(renamed)

def checkfold(target, prev, net, probability=0.07):
    """Fold ``prev`` towards ``target`` with chance ``probability``.

    :return: ``None``; ``net`` is modified in place when the fold happens.
    """
    if not random() > probability:
        dofold(target, prev, net)

def fold(net, num=100):
    """Do num path foldings on ``net`` (modified in place).

    Each folding picks a random start and a distinct random target among the
    current nodes. It routes greedily between them with findroute(), then
    gives the intermediate route nodes a chance to switch to the target
    location via checkfold().

    :param net: dict mapping each node location to its ascending neighbour list.
    :param num: number of foldings to perform.
    :raises ValueError: if ``net`` has fewer than two nodes (no start/target
        pair exists, which previously looped forever).
    :return: the lengths of all used routes (0 for a failed route)."""
    routelengths = []
    for _ in range(num):
        # re-read the keys each time: folding renames nodes in ``net``
        nodes = list(net)
        if len(nodes) < 2:
            raise ValueError("fold needs a network with at least two nodes.")
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)
        route = findroute(target, start, net)
        routelengths.append(len(route))
        # ``route`` does not contain ``start`` and ends with ``target``, so this
        # slice skips both the first hop and the target itself.
        # NOTE(review): the original comment said "all except the start and
        # the endpoint", which would be route[:-1] — confirm which is intended.
        for prev in route[1:-1]:
            checkfold(target, prev, net)
    return routelengths

def linklengths(net):
    """Return the length ``abs(peer - node)`` of every link in ``net``.

    Each link appears once per endpoint, i.e. twice for a symmetric network.
    """
    return [abs(peer - node) for node, peers in net.items() for peer in peers]

if __name__ == "__main__":
    # Run 0 folds a copy of a small-world network, run 1 a randomly connected
    # flat network; link-length snapshots of both are plotted at the end.
    nodes = randomnodes()
    basenet = generatesmallworldunclean(nodes)
    # (run, step) -> (link lengths, route lengths of the folds in that step)
    lensnapshots = {}
    foldperstep = 50
    for run in range(2):
        if not run:
            net = deepcopy(basenet)
        else:
            net = generateflat(nodes)
        lensnapshots[(run, 0)] = linklengths(net), [0]
        print("===", "run", run, "===")
        # mean link length of this run's unfolded network
        print(np.mean(lensnapshots[(run, 0)][0]))
        for i in range(40):
            lengths = fold(net, foldperstep)
            lensnapshots[(run, i+1)] = linklengths(net), lengths
            print(np.mean(lensnapshots[(run, i+1)][0]), np.mean(lensnapshots[(run, i+1)][1]))

    for (run, i), (linklen, routelen) in sorted(lensnapshots.items()):
        # only plot one in 10 results
        if i % 10:
            continue
        # `density` replaces the `normed` argument removed from matplotlib's hist
        pl.hist(linklen, 10000, cumulative=True, density=True, histtype='step',
                label=str(run) + ", " + str(i*foldperstep) + ", " + str(np.mean(routelen)))
    pl.semilogx()
    pl.legend(loc="best")
    pl.savefig("123.png")