#!/usr/bin/env python3

"""very simple gnutella path folding simulator.

For a full check, run the following shell loop, which sweeps network size, connections per node, HTL and strategy:

for size in 100 1000 10000 ; do for conn in 10 20 30 40 ; do for htl in 10 20 40 80 ; do for strategy in switch switchalways switchonesided jump  jumphalf jumpgolden jumpgoldensmall connect connectsimple replacebest replacelongest; do echo ===== $size nodes $conn conn $htl htl $strategy ===== ; python3 sim.py --size $size --connections $conn --perstep 600 --strategy $strategy --maxhtl $htl -o $size-$conn-htl-$htl-$strategy.png; done; done; done; done
"""

from random import random, choice, randint, shuffle
import numpy as np
import bisect
from copy import deepcopy
import math

def dist(loc0, loc1):
    """Return the circular distance between two locations in [0, 1).

    The key space wraps around at 1, so the result is the shorter of the
    two arcs between the locations and never exceeds 0.5.
    """
    clockwise = (loc0 - loc1) % 1
    counterclockwise = (loc1 - loc0) % 1
    # mirror min(): the first argument wins unless the second is smaller
    return counterclockwise if counterclockwise < clockwise else clockwise

def distances(target, nodelist):
    """Return the circular distance from each node in nodelist to target.

    The result is a list in the same order as ``nodelist``.
    """
    return [dist(location, target) for location in nodelist]

def randomnodes(num=8000):
    """Return num distinct random locations in [0, 1), sorted ascending.

    Draws are repeated until ``num`` unique values have been collected,
    so a duplicate draw simply costs one extra call to ``random()``.
    """
    chosen = set()
    while len(chosen) < num:
        chosen.add(random())
    return sorted(chosen)

def generateflat(nodes, connectionspernode=20):
    """Generate a randomly connected network.

    Every node in ``nodes`` gets random partners until it has at least
    ``connectionspernode`` links. Links are symmetric: when ``node`` picks
    ``conn``, each one is appended to the other's neighbour list. A
    candidate partner is rejected if it is the node itself, is already
    linked, or already has more than ``connectionspernode + 2`` links.

    NOTE(review): the retry loop can spin forever if no eligible partner
    is left (e.g. a network with very few nodes).

    :param nodes: sequence of node locations to pick partners from.
    :param connectionspernode: number of links each node asks for.
    :return: dict mapping every location to its sorted neighbour list.
    """
    net = {}
    for node in nodes:
        if node not in net:
            net[node] = []
        for _ in range(connectionspernode - len(net[node])):
            # find connections to others we do not know yet
            conn = choice(nodes)
            while (conn == node or
                   conn in net and (
                       conn in net[node] or
                       node in net[conn] or
                       net[conn][connectionspernode+2:])):
                conn = choice(nodes)
            net[node].append(conn)
            if conn not in net:
                net[conn] = []
            net[conn].append(node)
    # Neighbour lists must be sorted, as in generatesmallworldunclean.
    # Sort in place: a bare ``==`` comparison would be a no-op.
    for conns in net.values():
        conns.sort()
    return net

def closetolocrandomdist(loc, nodes, lennodes):
    """Return the index of the node just below loc in the sorted list nodes.

    Because locations are assumed to be uniformly distributed, the node
    near ``loc`` should sit near index ``loc * lennodes``. The bisection
    is therefore restricted to a window of about +-10 indices around that
    estimate. This acts as automatic binning, at the cost of being
    imprecise when the distribution is uneven (the result is clamped to
    the window). The result is ``bisect_left - 1``, which can be -1; as a
    Python index, -1 wraps around to the last node.
    """
    slack = 10/lennodes
    lowindex = max(0, int((loc - slack) * lennodes))
    highindex = min(lennodes, int((loc + slack) * lennodes))
    insertion = bisect.bisect_left(nodes, loc, lo=lowindex, hi=highindex)
    return insertion - 1

def smallworldtargetlocations(node, connectionspernode):
    """Return connectionspernode random link target locations around node.

    Small world routing wants the number of links into a region of the
    key space to scale with 1/distance. Candidate offsets are 0.5/k for
    k = 1 .. int(connectionspernode*100). Each offset is applied in both
    directions around the circle and wrapped into [0, 1). The candidates
    are shuffled and the first ``connectionspernode`` are returned.

    NOTE(review): the offsets 0.5/k lie roughly 2*d**2 apart near
    distance d, so sampling uniformly among them yields a link density
    ~ 1/d**2 (strongly favouring short links) rather than ~ 1/d.
    Confirm whether that is intended.
    """
    reach = 0.5
    candidates = []
    for k in range(1, int(connectionspernode*100) + 1):
        offset = reach / k
        candidates.append((node + offset) % 1)
        candidates.append((node - offset) % 1)
    shuffle(candidates)
    return candidates[:connectionspernode]

def generatesmallworldunclean(nodes, connectionspernode=20):
    """Generate a small world network with the given connectionspernode.

    For every node, the missing number of links is requested at the
    locations returned by smallworldtargetlocations(). For each location,
    the node found by closetolocrandomdist() (the node just below that
    location, within its search window) is tried first. If that candidate
    is unsuitable, random jumps of up to about lennodes/connectionspernode
    indices are tried until one fits. A candidate is unsuitable if it is
    the node itself, is already linked, or already has more than
    connectionspernode+10 links. Links are symmetric, and all neighbour
    lists are sorted at the end.

    NOTE(review): assumes ``nodes`` is sorted ascending, since
    closetolocrandomdist bisects it (randomnodes returns a sorted list).
    The retry loop may never terminate if no suitable partner exists,
    e.g. when there are fewer nodes than connectionspernode.
    """

    net = {}
    lennodes = len(nodes)
    for node in nodes:
        if not node in net:
            net[node] = []
        # only request the links this node does not have yet (may be <= 0)
        numconn = connectionspernode - len(net[node])
        for loc in smallworldtargetlocations(node, numconn):
            # find connections to others we do not know yet
            closestindex = closetolocrandomdist(loc, nodes, lennodes)
            conn = nodes[closestindex]
            # reject self, existing links and saturated partners; then
            # hop a random number of indices (wrapping) and retry
            while (conn == node or
                   conn in net and (
                    conn in net[node] or
                    node in net[conn] or
                   net[conn][connectionspernode+10:])):
                closestindex = (closestindex + randint(int(-lennodes/connectionspernode),int(lennodes/connectionspernode-1)))%len(nodes)
                conn = nodes[closestindex]
            net[node].append(conn)
            if not conn in net:
                net[conn] = []
            net[conn].append(node)

    # we need to make sure that all connections are sorted…
    for conns in net.values():
        conns.sort()
    return net

def closest(target, sortedlist):
    """Return the node in sortedlist which is closest to the target.

    Distances are circular, as computed by :func:`dist`. The whole list
    is scanned. Stopping at the first entry that does not improve is
    wrong on a circle, because along a sorted list the distance to
    ``target`` is not unimodal: nodes just below 1.0 can be the closest
    ones for a target just above 0.0. As a result the list does not
    actually need to be sorted. On ties, the earliest entry wins.

    :raises ValueError: if the list is empty.
    """
    if not sortedlist:
        raise ValueError("Cannot find the closest node in an emtpy list.")
    return min(sortedlist, key=lambda node: dist(node, target))

def findroute(target, start, net, maxhtl=9999):
    """Find a route from start to target using greedy routing.

    At each hop, the neighbour of the current node that is closest to
    ``target`` is taken. If that neighbour was already visited on this
    route, a random unvisited neighbour is chosen instead, so the route
    never repeats deterministically.

    :param maxhtl: maximum number of hops. A route whose last hop is
        exactly the ``maxhtl``-th one still counts as found.
    :return: the visited locations after ``start``, ending with
        ``target``. Returns ``[]`` if the target was not reached within
        ``maxhtl`` hops, if all neighbours of some hop were already
        visited, or if ``start == target``.
    """
    if start == target:
        return []
    best = start
    seen = set()
    route = []
    for _ in range(maxhtl):
        prev = best
        best = closest(target, net[prev])
        # never repeat a route deterministically
        if best in seen:
            possible = [node for node in net[prev] if node not in seen]
            if not possible:
                # can't reach the target from here: all routes already taken
                return []
            best = choice(possible)
        route.append(best)
        seen.add(best)
        # check right after the hop so the final (maxhtl-th) hop counts
        if best == target:
            return route
    return []

def numberofconnections(net):
    """Return the neighbour count of every node, in the net's iteration order."""
    return [len(neighbours) for neighbours in net.values()]

def smallworldbindef(numbins):
    """Return the upper edges of numbins logarithmic link-length bins.

    Each edge is ten times the previous one, and the last edge is 0.5
    (the largest circular distance). For example, ``numbins=4`` gives
    [0.0005, 0.005, 0.05, 0.5].
    """
    return [0.5/10**exponent for exponent in range(numbins - 1, -1, -1)]

def deviationfromsmallworld(linkdistances, numbins=5):
    """Per-bin deviation of a link length distribution from small world.

    Link lengths are sorted into ``numbins`` logarithmic bins. The upper
    edges come from :func:`smallworldbindef`. A length belongs to the
    first bin whose edge is larger than it, and lengths of 0.5 or more
    fall into no bin. An ideal small world network has the same share,
    1/numbins, of links in every bin.

    numbins should be clearly less than the number of connections.

    :return: list of abs(real share - ideal share), one entry per bin.
    """
    upperedges = smallworldbindef(numbins)
    total = len(linkdistances)
    shares = [0 for _ in range(numbins)]
    for length in linkdistances:
        # first edge strictly greater than the length
        slot = bisect.bisect_right(upperedges, length)
        if slot < numbins:
            shares[slot] += 1./total
    ideal = [1/numbins for _ in range(numbins)]
    return [abs(shares[k] - ideal[k]) for k in range(numbins)]

def jumptotarget(target, prev, net):
    """Move the node at location prev to (about) location target.

    The node keeps its neighbours. ``prev`` disappears from ``net``, and
    in every neighbour's list it is replaced by the new location (the
    list is re-sorted). If ``target`` is already taken, or equals
    ``prev``, it is nudged by random offsets in [-5e-10, 5e-10) until it
    is free. This avoids duplicate entries (implementation detail).
    """
    newloc = target
    while newloc == prev or newloc in net:
        newloc = (newloc + (random()-0.5)*1.e-9) % 1
    neighbours = list(net[prev])
    net[newloc] = neighbours
    del net[prev]
    for peer in neighbours:
        relinked = [c for c in net[peer] if c != prev]
        relinked.append(newloc)
        net[peer] = sorted(relinked)

def switchwithtarget(target, prev, net):
    """Switch the location of the node at prev with the node at target.

    Both nodes keep their neighbours but take over each other's location,
    so every reference to ``prev`` becomes ``target`` and vice versa.
    This is done with a single location mapping, so it stays correct when
    the two nodes are neighbours of each other or share neighbours.
    Replacing one location after the other would instead turn their
    direct link into self-links, and would drop one of the two links of
    a shared neighbour. All touched neighbour lists end up sorted.
    """
    if target == prev:
        return

    def swapped(loc):
        # location of the node formerly at ``loc`` after the switch
        if loc == prev:
            return target
        if loc == target:
            return prev
        return loc

    prevconnections = net[prev]
    targetconnections = net[target]
    # other nodes whose lists mention prev and/or target
    affected = (set(prevconnections) | set(targetconnections)) - {prev, target}
    net[target] = sorted(swapped(c) for c in prevconnections)
    net[prev] = sorted(swapped(c) for c in targetconnections)
    for conn in affected:
        net[conn] = sorted(swapped(c) for c in net[conn])

def connecttotarget(target, prev, net):
    """Connect prev to the target (symmetric link).

    Both neighbour lists stay sorted (``bisect.insort``), like the sorted
    lists that jumptotarget and switchwithtarget produce.

    NOTE(review): there is no check for an existing link. Connecting two
    nodes that are already neighbours adds a duplicate entry to both lists.
    """
    bisect.insort(net[prev], target)
    bisect.insort(net[target], prev)

def disconnectfromtarget(target, prev, net):
    """Remove the symmetric link between prev and target.

    Only the first matching entry on each side is removed. The ``prev``
    side is handled first.

    :raises ValueError: from ``list.remove`` if a side does not list the other.
    """
    for owner, other in ((prev, target), (target, prev)):
        net[owner].remove(other)
    
def hasbetterlinks(target, prev, net):
    """Return True if prev's neighbours would look more small-world from target.

    The neighbours of the node at ``prev`` stay fixed. Their link lengths
    are measured once from ``prev`` and once from ``target``, and the
    summed deviations from :func:`deviationfromsmallworld` are compared.
    Returns None (falsy) when ``target`` is not strictly better.
    """
    neighbours = net[prev][:]
    deviation_there = sum(deviationfromsmallworld(distances(target, neighbours)))
    deviation_here = sum(deviationfromsmallworld(distances(prev, neighbours)))
    return True if deviation_there < deviation_here else None

def checkjump(target, prev, net):
    """Move the node at prev onto target if its links look better from there.

    :return: True if the node jumped, otherwise None.
    """
    if not hasbetterlinks(target, prev, net):
        return None
    jumptotarget(target, prev, net)
    return True

def checkswitch(target, prev, net):
    """Switch the nodes at prev and target only if both of them benefit.

    Each node's links must look more small-world from the other's
    location. The second check is skipped when the first one fails.

    :return: True if the nodes switched, otherwise None.
    """
    if hasbetterlinks(target, prev, net) and hasbetterlinks(prev, target, net):
        switchwithtarget(target, prev, net)
        return True
    return None

def checkswitchalways(target, prev, net):
    """Unconditionally switch the nodes at prev and target.

    :return: always True.
    """
    switchwithtarget(target, prev, net)
    switched = True
    return switched

def checkswitchonesided(target, prev, net):
    """Switch positions with the target if that improves prev's links.

    Only the node at ``prev`` is asked. Whether the switch is good for
    the node at ``target`` is not considered.

    :return: True if the nodes switched, otherwise None.
    """
    improves = hasbetterlinks(target, prev, net)
    if not improves:
        return None
    switchwithtarget(target, prev, net)
    return True

def checkjumphalf(target, prev, net):
    """Check if the node at prev should jump halfway to the target.

    The midpoint is taken along the shorter arc of the circular key
    space, across the 0/1 edge when that is shorter. The jump happens
    only if the node's current links look more like a small world
    distribution from the midpoint. The work is delegated to
    :func:`checkjumpratio` with ratio 0.5. Previously the wrap-aware
    midpoint was overwritten by ``(prev+target)/2``, which yielded a
    quarter-way jump that ignored wrapping.

    :return: True if the node jumped, otherwise None.
    """
    return checkjumpratio(target, prev, net, 0.5)

def checkjumpratio(target, prev, net, ratio):
    """Check if the node at prev should jump ratio of the way to target.

    The interpolation follows the shorter arc of the circular key space.
    When going over the 0/1 edge is shorter, one endpoint is shifted by
    -1 before interpolating, and the result is wrapped back into [0, 1).
    The jump happens only if the deviation of the node's link lengths
    from a small world distribution is smaller at the new location.

    :param ratio: 0 stays at ``prev``, 1 moves onto ``target``.
    :return: True if the node jumped, otherwise None.
    """
    neighbours = net[prev][:]
    start, end = prev, target
    # if the distance through the edge is shorter, go that way
    crosses_edge = dist(target, prev) + 1.e-9 < abs(target - prev)
    if crosses_edge:
        if start < end:
            end = end - 1
        elif start > end:
            start = start - 1
    newloc = start * (1-ratio) + end * ratio
    if crosses_edge:
        newloc = newloc % 1
    deviation_new = sum(deviationfromsmallworld(distances(newloc, neighbours)))
    deviation_old = sum(deviationfromsmallworld(distances(prev, neighbours)))
    if not deviation_new < deviation_old:
        return None
    jumptotarget(newloc, prev, net)
    return True

def checkjumpgolden(target, prev, net, small=False):
    """Check if prev should jump a golden-ratio fraction of the way to target.

    With ``small=False``, the node moves the fraction (√5-1)/2 ≈ 0.618
    of the way and ends up closer to ``target``. With ``small=True``, it
    moves 1 - 0.618 ≈ 0.382 of the way and stays closer to ``prev``. The
    decision and the move are made by :func:`checkjumpratio`.

    :return: True if the node jumped, otherwise None.
    """
    inverse_phi = (math.sqrt(5)-1)/2.0
    fraction = 1-inverse_phi if small else inverse_phi
    return checkjumpratio(target, prev, net, fraction)

def checkconnect(target, prev, net):
    """Check if prev should link to target, dropping a link from its worst bin.

    The link lengths of ``prev`` are binned into 5 logarithmic bins (see
    deviationfromsmallworld / smallworldbindef). The bin that deviates
    most from its ideal share is the "worst" one, and a random link
    length from that bin becomes the candidate for removal. If replacing
    that length with ``dist(prev, target)`` lowers the summed deviation,
    ``prev`` connects to ``target``. It then drops the link to the
    candidate neighbour, but only if that neighbour has more than
    ``len(net[prev]) // 2`` links (counted after the new link is added).

    Note: the evaluation removes *every* link whose length equals the
    candidate's, while at most one link is actually dropped.

    :return: True if ``target`` was connected, otherwise None.
    """
    oldconns = net[prev][:]
    olddist = distances(prev, oldconns)
    deviation = deviationfromsmallworld(olddist, numbins=5)
    bindefupper = smallworldbindef(numbins=5)
    worstidx = deviation.index(max(deviation))
    worstupper = bindefupper[worstidx]
    # lower edge of the worst bin (the first bin starts at 0)
    if worstidx:
        worstlower = bindefupper[worstidx-1]
    else:
        worstlower = 0
    baddist = [i for i in olddist if worstlower <= i < worstupper]
    if not baddist:
        return # no really bad connections
    tokill = choice(baddist)
    # hypothetical link lengths after replacing tokill with the new link
    newdist = [i for i in olddist if i != tokill] + [dist(prev, target)]
    newdeviation = deviationfromsmallworld(newdist, numbins=5)
    if sum(newdeviation) < sum(deviation):
        connecttotarget(target, prev, net)
        # only disconnect tokill if it keeps enough connections: it must
        # have more than len(net[prev]) // 2 links (slice is non-empty).
        tokillconn = oldconns[olddist.index(tokill)]
        enoughconnections = net[tokillconn][int(len(net[prev])/2):]
        if enoughconnections:
            disconnectfromtarget(tokillconn, prev, net)
        return True

def checkreplacebest(target, prev, net):
    """Connect prev to target and drop a link from its best-matching bin.

    The link lengths of ``prev`` are binned into 5 logarithmic bins (see
    deviationfromsmallworld / smallworldbindef). The bin closest to its
    ideal share is the "best" one. If that bin holds any link, ``prev``
    always connects to ``target`` (no improvement check is made). It then
    drops a randomly chosen link from that bin, provided the dropped
    neighbour has more than ``len(net[prev]) // 2`` links (counted after
    the new link is added).

    :return: True if ``target`` was connected, None if the best bin was
        empty.
    """
    oldconns = net[prev][:]
    olddist = distances(prev, oldconns)
    deviation = deviationfromsmallworld(olddist, numbins=5)
    bindefupper = smallworldbindef(numbins=5)
    bestidx = deviation.index(min(deviation))
    bestupper = bindefupper[bestidx]
    # lower edge of the best bin (the first bin starts at 0)
    if bestidx:
        bestlower = bindefupper[bestidx-1]
    else:
        bestlower = 0
    gooddist = [i for i in olddist if bestlower <= i < bestupper]
    if not gooddist:
        return # no really good connections
    tokill = choice(gooddist)
    connecttotarget(target, prev, net)
    # only disconnect tokill if it keeps enough connections: it must
    # have more than len(net[prev]) // 2 links (slice is non-empty).
    tokillconn = oldconns[olddist.index(tokill)]
    enoughconnections = net[tokillconn][int(len(net[prev])/2):]
    if enoughconnections:
        disconnectfromtarget(tokillconn, prev, net)
    return True
        
def checksimpleconnect(target, prev, net):
    """Connect prev to the target and drop a random old link of prev.

    The dropped neighbour is picked from ``prev``'s links as they were
    before connecting. The link is only removed if that neighbour has
    more than ``len(net[prev]) // 2`` links (counted after connecting).

    :return: always True.
    """
    previous = list(net[prev])
    connecttotarget(target, prev, net)
    victim = choice(previous)
    if net[victim][len(net[prev]) // 2:]:
        disconnectfromtarget(victim, prev, net)
    return True

def checkreplacelongest(target, prev, net):
    """Connect prev to the target and drop prev's longest old link.

    The link to the neighbour that was farthest away (circular distance,
    measured before the new link is added) is removed. This only happens
    if that neighbour has more than ``len(net[prev]) // 2`` links
    (counted after connecting). If ``prev`` had no links before, it just
    gains the new one.

    :return: always True.
    """
    conns = net[prev][:]
    connecttotarget(target, prev, net)
    if not conns:
        # nothing to replace
        return True
    # max() keeps the first of equally long links, like list.index(max(...))
    tokillconn = max(conns, key=lambda conn: dist(conn, prev))
    enoughconnections = net[tokillconn][len(net[prev]) // 2:]
    if enoughconnections:
        disconnectfromtarget(tokillconn, prev, net)
    return True

def checkfold(target, prev, net, probability=0.07, strategy="jumphalf"):
    """With the given probability, try to fold prev using strategy.

    :param probability: chance that a fold is attempted at all.
    :param strategy: name of the optimisation strategy. Unknown names
        are ignored.
    :return: the result of the strategy's check (True when the network
        was changed), or None when no fold was attempted.
    """
    if random() > probability:
        return None
    # lambdas so that only the selected strategy is looked up and run
    handlers = {
        "jump": lambda: checkjump(target, prev, net),
        "switch": lambda: checkswitch(target, prev, net),
        "switchalways": lambda: checkswitchalways(target, prev, net),
        "switchonesided": lambda: checkswitchonesided(target, prev, net),
        "jumphalf": lambda: checkjumphalf(target, prev, net),
        "jumpgolden": lambda: checkjumpgolden(target, prev, net),
        "jumpgoldensmall": lambda: checkjumpgolden(target, prev, net, small=True),
        "connect": lambda: checkconnect(target, prev, net),
        "replacebest": lambda: checkreplacebest(target, prev, net),
        "replacelongest": lambda: checkreplacelongest(target, prev, net),
        "connectsimple": lambda: checksimpleconnect(target, prev, net),
    }
    handler = handlers.get(strategy)
    return handler() if handler is not None else None

def fold(net, num=100, strategy="jumphalf", maxhtl=20, fullroutes=True):
    """Do num path foldings.

    For each request, a random start and a different random target are
    picked from the current node locations and routed with
    :func:`findroute`. If the route is at most ``maxhtl`` hops long,
    every hop except the last one (the target) is offered a fold via
    :func:`checkfold`. The start is not part of the route. After a
    successful fold, the folded hop's location becomes the target for
    the remaining hops.

    :param fullroutes: Calculate the full route, regardless of
        maxhtl. Only use the HTL for deciding whether to try to fold
        the route.
    :return: the lengths of all used routes. With ``fullroutes=False``,
        failed routes are recorded as ``maxhtl + 1``.

    NOTE(review): with ``fullroutes=True``, a failed route (findroute
    returns ``[]``) is not recorded at all. Success rates computed from
    the result therefore only cover routes that were found.
    """
    routelengths = []
    for _ in range(num):
        # locations can change while folding (jump strategies), so
        # re-read them for every request
        nodes = list(net.keys())
        start = choice(nodes)
        target = choice(nodes)
        while target == start:
            target = choice(nodes)
        if fullroutes:
            route = findroute(target, start, net, maxhtl=9999)
        else:
            route = findroute(target, start, net, maxhtl=maxhtl)
        if not route:
            # in case we break the route at maxhtl, we just add the
            # HTL+1 to be able to do the simplistic success stats
            # later.
            if not fullroutes:
                routelengths.append(maxhtl+1)
            continue
        routelen = len(route)
        routelengths.append(routelen)
        if routelen > maxhtl:
            continue
        # fold all on the route except for the endpoint (the start is
        # not part of the route)
        for prev in route[:-1]:
            if checkfold(target, prev, net, strategy=strategy):
                target = prev
    return routelengths

def linklengths(net):
    """Return the circular length of every adjacency-list entry in net.

    Each entry contributes one length, so a symmetric link is counted
    once from each of its ends.
    """
    return [length
            for node, targets in net.items()
            for length in distances(node, targets)]

def parse_args():
    """Parse the simulator's command line.

    ``--strategy`` is restricted to the names that checkfold dispatches
    on. checkfold silently ignores any other name, which would quietly
    disable all folding, so argparse rejects it with a usage error.
    """
    from argparse import ArgumentParser
    strategies = ("switch", "switchalways", "switchonesided", "jump",
                  "jumphalf", "jumpgolden", "jumpgoldensmall", "connect",
                  "connectsimple", "replacebest", "replacelongest")
    parser = ArgumentParser(description="Simulate Freenet network optimization.")
    parser.add_argument("--strategy", help="The optimization strategy: switch switchalways switchonesided jump  jumphalf jumpgolden jumpgoldensmall connect connectsimple replacebest replacelongest", default="jumphalf", choices=strategies)
    parser.add_argument("--starting-config", help="The starting configurations: flat or ideal. Take two by seperating them with a comma", default="ideal,flat")
    parser.add_argument("--size", help="The size of the network.", default=1000, type=int)
    parser.add_argument("--connections", help="The mean number of connections per node.", default=20, type=int)
    parser.add_argument("--maxhtl", help="The maximum length of routes to be successful.", default=20, type=int)
    parser.add_argument("--steps", help="The maximum number of modelsteps.", default=60, type=int)
    parser.add_argument("--perstep", help="The number of requests to run per step.", default=100, type=int)
    parser.add_argument("--nofullroutes", help="Don’t calculate full routes but break at maxhtl. This gives weaker statistics but is much faster for random networks with very long routes and very low success rate", action="store_true")
    parser.add_argument("-o", "--output", help="Filename of the pylab plot.", default="plot.png")
    return parser.parse_args()

if __name__ == "__main__":
    def _successrate(routelengths, maxhtl):
        """Fraction of routes shorter than maxhtl (0.0 if there are none)."""
        if not routelengths:
            return 0.0
        return len([r for r in routelengths if r < maxhtl])/len(routelengths)

    def _printstats(linklens, routelengths, maxhtl):
        """Print link length and route statistics for one folding step."""
        if not routelengths:
            # min()/max() and the success ratio would fail on no routes
            print(np.mean(linklens), "no routes found")
            return
        print(np.mean(linklens), np.mean(routelengths), "±", np.std(routelengths),
              "succ", _successrate(routelengths, maxhtl),
              min(routelengths), max(routelengths),
              sum(deviationfromsmallworld(linklens, numbins=10)))

    args = parse_args()
    nodes = randomnodes(args.size)
    lensnapshots = {}
    foldperstep = args.perstep
    runs = args.starting_config.split(",")
    for run in runs:
        if run == "ideal":
            net = generatesmallworldunclean(nodes, args.connections)
        else:
            net = generateflat(nodes, args.connections)
        # step 0 is the unoptimized network; [0] is a placeholder route length
        lensnapshots[(run, 0)] = linklengths(net), [0]
        print(np.mean(lensnapshots[(run, 0)][0]))
        print("===", run, "===")
        routelengths = fold(net, foldperstep, args.strategy, args.maxhtl,
                            fullroutes=not args.nofullroutes)
        _printstats(linklengths(net), routelengths, args.maxhtl)
        for i in range(args.steps):
            routelengths = fold(net, foldperstep, args.strategy, args.maxhtl,
                                fullroutes=not args.nofullroutes)
            linklens = linklengths(net)
            lensnapshots[(run, i+1)] = linklens, routelengths
            _printstats(linklens, routelengths, args.maxhtl)

    print()
    # now plot the data
    import pylab as pl
    pl.rcParams['font.size'] = 5
    for key, val in sorted(lensnapshots.items()):
        run, i = key
        linklen, routelen = val
        # only plot step 0, step 1 and every 10th step
        if i % 10 and i != 1:
            continue
        succ = _successrate(routelen, args.maxhtl)
        # ``normed`` was removed from matplotlib's hist(); ``density`` is
        # its replacement with the same meaning.
        pl.hist(linklen, 10000, cumulative=True, density=True, histtype='step',
                label=str(run) + ", " + "{:}".format(i*foldperstep) + ", len: " + "{:.2f}".format(np.mean(routelen)) + ", succ: " + "{:.2f}".format(succ))
        pl.semilogx()
    pl.legend(loc="upper left")
    pl.title(str(args.size) + " nodes, " + str(args.connections) + " connections per node, optimized with strategy " + args.strategy + ". succ: found in max " + str(args.maxhtl) + " steps.")
    pl.savefig(args.output)