#!/usr/bin/env python3

"""Very simple Gnutella path-folding simulator."""

from random import random, choice
from collections import defaultdict

def randomnodes(num=8000):
    """Return num distinct random node locations drawn from [0, 1)."""
    locations = set()
    # Keep drawing until enough distinct values are collected; a repeated
    # draw leaves the set unchanged, so another value is drawn instead.
    while len(locations) < num:
        locations.add(random())
    return list(locations)

def generateflat(nodes, connectionspernode=10):
    """Generate a randomly connected, undirected network.

    Every node is linked to randomly chosen other nodes until it has at
    least ``connectionspernode`` connections. A candidate peer is accepted
    while it holds at most ``connectionspernode`` connections, so some nodes
    end up with one extra link (presumably so construction does not stall
    once most peers are full).

    :param nodes: list of distinct node locations.
    :param connectionspernode: target number of connections per node.
    :returns: ``defaultdict(list)`` mapping each node to its neighbours.
        Every neighbour list is sorted ascending, so it can be passed to
        ``closest`` as its ``sortedlist``.
    """
    net = defaultdict(list)
    for node in nodes:
        for _ in range(connectionspernode - len(net[node])):
            conn = choice(nodes)
            # Reject self-links, duplicate links and peers that already
            # have more than connectionspernode links.
            while (conn == node or
                   node in net[conn] or
                   len(net[conn]) > connectionspernode):
                conn = choice(nodes)
            net[node].append(conn)
            net[conn].append(node)
    # Sort in place: greedy routing treats neighbour lists as sorted.
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """Return the node in ``sortedlist`` that is closest to ``target``.

    Distance is the plain absolute difference ``abs(node - target)``. The
    whole list is scanned, so the result is correct even if the list is not
    actually sorted. On ties the earliest entry wins, which for an ascending
    list is the smaller node.

    :raises ValueError: if ``sortedlist`` is empty.
    """
    if not sortedlist:
        raise ValueError("closest() needs at least one candidate node")
    return min(sortedlist, key=lambda node: abs(node - target))

def findroute(target, start, net):
    """Find a route from start to target using greedy routing.

    At each hop the neighbour closest to ``target`` is taken. If that
    neighbour was already visited, a random neighbour is taken instead, so
    a loop is never retraced deterministically. The returned list of hops
    excludes ``start``; unless ``start == target`` (which yields ``[]``), it
    ends with ``target``. The loop has no step limit and only stops once
    ``target`` is reached.
    """
    visited = set()
    hops = []
    current = start
    while current != target:
        nxt = closest(target, net[current])
        if nxt in visited:
            # never repeat a route deterministically
            nxt = choice(net[current])
        hops.append(nxt)
        visited.add(nxt)
        current = nxt
    return hops

def checkfold(target, prev, net, probability=0.07):
    """Switch ``prev`` to the target location with some probability.

    If the switch happens, ``prev`` is renamed to a location a tiny random
    offset away from ``target`` and keeps all of its connections. Its
    neighbours' lists are updated and re-sorted.

    Nothing happens in any of these cases:

    * ``prev`` is no longer in ``net``, e.g. it already switched earlier on
      the same route.
    * ``prev`` is the target itself.
    * ``prev`` is already connected to ``target``.

    :param target: location to fold towards.
    :param prev: node on the route that may switch location.
    :param net: mapping node -> list of neighbours; modified in place.
    :param probability: chance that the switch is attempted.
    """
    if random() > probability:
        return
    # .get() so that a stale node is not re-inserted by a defaultdict.
    conns = net.get(prev)
    # never switch with a neighbour; this must be checked against the real
    # target before it is nudged below. Also never fold a node onto itself.
    if conns is None or prev == target or target in conns:
        return
    # do not switch to exactly the target to avoid duplicate entries
    # (implementation detail): nudge by at most +/-5e-10, wrapping into [0, 1)
    location = target
    while location in net:
        location = (location + (random() - 0.5) * 1.e-9) % 1
    net[location] = conns
    del net[prev]
    for conn in conns:
        neighbours = net[conn]
        neighbours.remove(prev)
        neighbours.append(location)
        neighbours.sort()

def fold(net, num=10):
    """Do num path foldings on net in place.

    Each round picks two distinct random nodes, routes greedily from one to
    the other with ``findroute``, and gives every hop on that route a chance
    (via ``checkfold``) to switch to the target's location.
    """
    for _ in range(num):
        nodes = list(net)
        if len(nodes) < 2:
            # No distinct start/target pair exists; the selection loop
            # below would never terminate.
            return
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)

        route = findroute(target, start, net)
        for prev in route:
            # A node can appear on the route more than once. Once it has
            # switched location its old key is no longer in net, so it must
            # not be folded again.
            if prev in net:
                checkfold(target, prev, net)

def linklengths(net):
    """Calculate the lengths of all links.

    One length ``abs(neighbour - node)`` is produced per (node, neighbour)
    entry in ``net``. A network whose links are stored at both endpoints
    therefore reports each link twice.
    """
    return [abs(peer - node) for node, peers in net.items() for peer in peers]

if __name__ == "__main__":
    import numpy as np
    nodes = randomnodes()
    net = generateflat(nodes)
    # mean link length before folding, then after each of three fold rounds
    print(np.mean(linklengths(net)))
    for _ in range(3):
        fold(net)
        print(np.mean(linklengths(net)))