#!/usr/bin/env python3

"""very simple gnutella path folding simulator."""

from random import random, choice
import numpy as np

def randomnodes(num=8000):
    """Create num distinct random node locations.

    Every location comes from random(), so it lies in the interval [0, 1).
    A draw that collides with an already chosen location is repeated, which
    guarantees that exactly num different values are produced.

    Returns the locations as a list in set iteration order (not sorted).
    """
    locations = set()
    for _ in range(num):
        wanted = len(locations) + 1
        # adding a duplicate leaves the set unchanged, so keep drawing
        # until the set really grew by one new location
        while len(locations) < wanted:
            locations.add(random())
    return list(locations)

def _canlink(node, conn, net, connectionspernode):
    """Tell whether node may open a new link to conn."""
    if conn == node:
        return False  # no links to ourselves
    if conn not in net:
        return True  # conn has no links at all yet
    linked = conn in net[node] or node in net[conn]
    # a peer only counts as full once it holds MORE than connectionspernode
    # links, so a node can end up with connectionspernode + 1 links
    full = len(net[conn]) > connectionspernode
    return not (linked or full)


def _pickpeer(node, nodes, net, connectionspernode, tries=100):
    """Pick a random peer node may link to, or None if there is none."""
    for _ in range(tries):
        conn = choice(nodes)
        if _canlink(node, conn, net, connectionspernode):
            return conn
    # random sampling kept failing: look at every node once so that we
    # either find one of the few remaining candidates or know there is none
    candidates = [c for c in nodes
                  if _canlink(node, c, net, connectionspernode)]
    return choice(candidates) if candidates else None


def generateflat(nodes, connectionspernode=10):
    """Generate a randomly connected network.

    nodes is a sequence of node locations. Each node is linked to randomly
    chosen other nodes until it has at least connectionspernode links or no
    eligible peer is left. Links are symmetric: whenever b is added to the
    links of a, a is added to the links of b. A node never links to itself
    and never links to the same peer twice.

    Returns a dict which maps every node to the sorted list of the nodes it
    is linked to.
    """
    net = {}
    for node in nodes:
        links = net.setdefault(node, [])
        # links created by earlier nodes already count towards the quota
        for _ in range(connectionspernode - len(links)):
            conn = _pickpeer(node, nodes, net, connectionspernode)
            if conn is None:
                break  # nobody left to connect to; stop instead of spinning
            links.append(conn)
            net.setdefault(conn, []).append(node)
    # closest() expects the link lists in sorted order
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """Return the node in sortedlist which is closest to the target.

    The distance between a node and the target is abs(node - target). If
    several nodes are equally close, the one which comes first in the list
    is returned.

    Every entry is inspected, so the result is also correct when the list
    is not actually sorted or when the values lie outside of [0, 1].

    Raises ValueError if sortedlist is empty.
    """
    if not sortedlist:
        raise ValueError("Cannot find the closest node in an emtpy list.")
    # min() keeps the first of several equally good entries
    return min(sortedlist, key=lambda n: abs(n - target))

def findroute(target, start, net):
    """Find a route from start to the target using greedy routing.

    net maps each node to the list of nodes it is linked to. In every step
    the route continues with the neighbor of the current node which
    closest() selects for the target. If that neighbor was already visited,
    a random neighbor which was not visited yet is used instead.

    Returns the list of nodes visited after start. The list ends with the
    target when routing succeeded. It is empty if start already is the
    target or if routing got stuck because every neighbor of the current
    node was already visited (a message is printed in the latter case).
    """
    best = start
    # the start counts as visited, too: otherwise the route could bounce
    # back through the start node and dead-end there
    seen = {start}
    route = []
    while best != target:
        prev = best
        best = closest(target, net[prev])
        # never repeat a route deterministically
        if best in seen:
            possible = [node for node in net[prev] if node not in seen]
            if not possible:
                print ("all routes already taken from", start, "to", target, "over", prev)
                return [] # can't reach the target from here: all routes already taken
            best = choice(possible)
        route.append(best)
        seen.add(best)
    return route

def numberofconnections(net):
    """Return a list with the number of links of every node in net.

    The counts are in the iteration order of the net dict.
    """
    return [len(links) for links in net.values()]

def dofold(target, prev, net):
    """Move the node prev to (almost) the location of the target.

    The node keeps its links; only its location changes. net is modified
    in place: the entry for prev is replaced by an entry for the new
    location, and every linked node gets the new location in place of prev
    in its (re-sorted) link list.
    """
    # The new location must not collide with an existing node, so nudge it
    # by a tiny random offset (wrapping around at 1) until it is unused.
    newloc = target
    while newloc in net or newloc == prev:
        newloc = (newloc + (random() - 0.5) * 1.e-9) % 1
    peers = net.pop(prev)
    net[newloc] = peers
    for peer in peers:
        relinked = [other for other in net[peer] if other != prev]
        relinked.append(newloc)
        relinked.sort()
        net[peer] = relinked

def checkfold(target, prev, net, probability=0.7):
    """Fold prev towards the target, but only with the given probability.

    A random number is drawn; when it is larger than probability nothing
    happens, otherwise dofold() moves prev to the target location.
    """
    skip = random() > probability
    if not skip:
        dofold(target, prev, net)

def fold(net, num=10):
    """Do num path foldings on net (modified in place).

    For each folding a random start node and a different random target
    node are chosen and a route between them is searched with findroute().
    Every node on the returned route, except for its first and its last
    entry, is then passed to checkfold() with the target.

    Nothing is done if net has fewer than two nodes, because no pair of
    distinct nodes can be chosen then.
    """
    for _ in range(num):
        # the node locations change with every folding, so list them anew
        nodes = list(net)
        if len(nodes) < 2:
            return
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)

        route = findroute(target, start, net)
        # leave out the first and the last entry of the route
        for prev in route[1:-1]:
            if prev not in net:
                # the node vanished from the net while folding this route:
                # report it and skip it, folding it would raise a KeyError
                idx = route.index(prev)
                print (route[idx:], prev, target, idx, prev in route[:idx])
                continue
            checkfold(target, prev, net)
            

def linklengths(net):
    """Calculate the lengths of all links.

    The length of a link is the absolute difference between the locations
    of its two ends. The result has one entry per entry in the link lists
    of net, in iteration order.
    """
    return [abs(peer - node)
            for node, peers in net.items()
            for peer in peers]

if __name__ == "__main__":
    import numpy as np

    # build a random network and report its mean link length
    net = generateflat(randomnodes())
    print(np.mean(linklengths(net)))
    # 100 rounds of 10 path foldings each; report the mean link length
    # again after every round
    for _ in range(100):
        fold(net, 10)
        print(np.mean(linklengths(net)))