From e283123b3a02cd650828a372779e9c0d558021f4 Mon Sep 17 00:00:00 2001
From: Reid 'arrdem' McKenzie <me@arrdem.com>
Date: Mon, 22 Nov 2021 01:32:11 -0700
Subject: [PATCH] Fix obviously colliding ICMP stream IDs

But it looks like the underlying ICMPLIB model of receive() is broken for what I'm doing
---
 projects/aloe/BUILD                       |   1 +
 projects/aloe/src/python/aloe/__main__.py |  95 ++++++++-----
 projects/aloe/src/python/aloe/urping.py   | 163 ++++++++++++++++++++++
 3 files changed, 223 insertions(+), 36 deletions(-)
 create mode 100644 projects/aloe/src/python/aloe/urping.py

diff --git a/projects/aloe/BUILD b/projects/aloe/BUILD
index d7a236d..93104b3 100644
--- a/projects/aloe/BUILD
+++ b/projects/aloe/BUILD
@@ -9,6 +9,7 @@ zapp_binary(
         ":lib",
         py_requirement("graphviz"),
         py_requirement("icmplib"),
+        py_requirement("pytz"),
         py_requirement("requests"),
     ],
 )
diff --git a/projects/aloe/src/python/aloe/__main__.py b/projects/aloe/src/python/aloe/__main__.py
index 540a787..b6ddb34 100644
--- a/projects/aloe/src/python/aloe/__main__.py
+++ b/projects/aloe/src/python/aloe/__main__.py
@@ -12,15 +12,19 @@ from datetime import datetime, timedelta
 from itertools import cycle
 import logging
 from multiprocessing import Process, Queue
+from random import randint
 import queue
 import sys
+from time import sleep
 from typing import List
 
 import graphviz
-from icmplib import Hop, ping, traceroute
+from icmplib import Hop, traceroute
 from icmplib.utils import *
 import requests
+import pytz
 
+from .urping import ping, urping
 
 log = logging.getLogger(__name__)
 
@@ -81,6 +85,8 @@ class Topology(object):
         del self._graph[key]
         del self._nodes[key]
 
+INTERVAL = 0.5
+Z = pytz.timezone("America/Denver")
 
 def compute_topology(hostlist, topology=None):
     """Walk a series of traceroute tuples, computing a 'worst expected latency' topology from them."""
@@ -89,20 +95,27 @@ def compute_topology(hostlist, topology=None):
     for h in hostlist:
         trace = traceroute(h)
         # Restrict the trace to hosts which ICMP ping
-        trace = [e for e in trace if ping(e.address, count=1).is_alive]
+        trace = [e for e in trace if ping(e.address, interval=INTERVAL, count=3).is_alive]
         topology.add_traceroute(trace)
 
     return topology
 
 
-def pinger(host, id, queue):
+def pinger(host, queue, next=None):
     # Mokney patch the RTT tracking
     host._rtts = deque(host._rtts, maxlen=100)
+    id = randint(1, 1<<16 - 1)
+    sequence = 0
+
     while True:
-        timeout = h.avg_rtt * 2 / 1000.0  # rtt is in ms but timeout is in sec.
-        start = datetime.now()
-        res = ping(host.address, id=id, timeout=timeout, count=3)
+        timeout = min(h.avg_rtt / 1000.0, 0.5)  # rtt is in ms but timeout is in sec.
+        start = datetime.now(tz=Z)
+
+        res = ping(host.address, timeout=timeout, interval=INTERVAL, count=3, id=id, sequence=sequence)
+        sequence += res._packets_sent
+
         queue.put((start, res))
+        sleep(INTERVAL)
         if res.is_alive:
             host._rtts.extend(res._rtts)
             host._packets_sent += res._packets_sent
@@ -112,28 +125,28 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.DEBUG)
     opts, args = parser.parse_known_args()
 
-    now = start = datetime.now()
+    now = start = datetime.now(tz=Z)
     reconfigure_delay = timedelta(minutes=5)
     configure_at = now - reconfigure_delay
-    flush_delay = timedelta(seconds=5)
+    flush_delay = timedelta(seconds=1)
     flush_at = now + flush_delay
 
     recovered_duration = timedelta(seconds=5)
     dead_duration = timedelta(minutes=30)
 
     topology = None
-    id = unique_identifier()
 
     q = Queue()
     workers = {}
     last_seen = {}
+    state = {}
 
     spinner = cycle("|/-\\")
 
     with open("incidents.txt", "a") as fp:
         fp.write("RESTART\n")
         while True:
-            now = datetime.now()
+            now = datetime.now(tz=Z)
 
             if flush_at <= now:
                 fp.flush()
@@ -141,53 +154,63 @@ if __name__ == "__main__":
 
             if configure_at <= now:
                 log.info("Attempting to reconfigure network topology...")
-                topology = compute_topology(opts.hosts, topology)
-                configure_at = now + reconfigure_delay
-                log.info("Graph -\n" + topology.render())
+                try:
+                    topology = compute_topology(opts.hosts, topology)
+                    configure_at = now + reconfigure_delay
+                    log.info("Graph -\n" + topology.render())
 
-                for h in topology:
-                    if h.distance == 0:
-                        continue
+                    for h in topology:
+                        if h.distance == 0:
+                            continue
 
-                    if h.address in workers:
-                        continue
+                        if h.address in workers:
+                            continue
 
-                    else:
-                        p = workers[h.address] = Process(target=pinger, args=(h, id, q))
-                        p.start()
+                        else:
+                            n = next(iter(topology.next_hops(h.address)), None)
+                            p = workers[h.address] = Process(target=pinger, args=(h, q, n))
+                            p.start()
+
+                except Exception as e:
+                    log.exception(e)
 
             try:
-                timestamp, res = q.get(timeout=0.1)
+                # Revert to "logical now" of whenever the last ping results came in.
+                now, res = q.get(timeout=0.1)
                 last = last_seen.get(res.address)
-                delta = timestamp - last if last else None
+                delta = now - last if last else None
+
+                sys.stderr.write("\r" + next(spinner) + " " + f"ICMPResponse({res.address}, {res._rtts}, {res._packets_sent})" + " " * 20)
+                sys.stderr.flush()
 
                 if res.address not in workers:
                     pass
 
                 elif res.is_alive:
-                    last_seen[res.address] = timestamp
+                    last_seen[res.address] = now
                     if last and delta > recovered_duration:
+                        state[res.address] = True
                         fp.write(
-                            f"RECOVERED\t{res.address}\t{timestamp.isoformat()}\t{delta.total_seconds()}\n"
+                            f"RECOVERED\t{res.address}\t{now.isoformat()}\t{delta.total_seconds()}\n"
                         )
                     elif not last:
-                        fp.write(f"UP\t{res.address}\t{timestamp.isoformat()}\n")
+                        state[res.address] = True
+                        fp.write(f"UP\t{res.address}\t{now.isoformat()}\n")
 
                 elif not res.is_alive:
                     if last and delta > dead_duration:
-                        workers[h.address].terminate()
-                        del workers[h.address]
-                        del topology[h.address]
-                        del last_seen[h.address]
+                        workers[res.address].terminate()
+                        del workers[res.address]
+                        del topology[res.address]
+                        del last_seen[res.address]
+                        del state[res.address]
                         fp.write(
-                            f"DEAD\t{res.address}\t{timestamp.isoformat()}\t{delta.total_seconds()}\n"
+                            f"DEAD\t{res.address}\t{now.isoformat()}\t{delta.total_seconds()}\n"
                         )
 
-                    elif last and delta < recovered_duration:
-                        fp.write(f"WARN\t{res.address}\t{timestamp.isoformat()}\n")
-
-                    elif last and delta > recovered_duration:
-                        fp.write(f"DOWN\t{res.address}\t{timestamp.isoformat()}\n")
+                    elif last and delta > recovered_duration and state[res.address]:
+                        fp.write(f"DOWN\t{res.address}\t{now.isoformat()}\n")
+                        state[res.address] = False
 
             except queue.Empty:
                 sys.stderr.write("\r" + next(spinner))
diff --git a/projects/aloe/src/python/aloe/urping.py b/projects/aloe/src/python/aloe/urping.py
new file mode 100644
index 0000000..4fba62a
--- /dev/null
+++ b/projects/aloe/src/python/aloe/urping.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+
+import logging
+from time import sleep
+import traceback
+from textwrap import indent
+from random import randint
+import sys
+
+from icmplib.exceptions import (
+    ICMPLibError,
+    TimeExceeded,
+    TimeoutExceeded,
+)
+from icmplib.models import Host, Hop, ICMPRequest, ICMPReply
+from icmplib.sockets import (
+    ICMPv4Socket,
+    ICMPv6Socket,
+)
+from icmplib.utils import *
+
+
+log = logging.getLogger(__name__)
+
+
+def better_repr(self):
+    """Format an object that declares ``__slots__`` as ``<TypeName (slot=value, ...)>``."""
+    fields = ", ".join("{}={}".format(name, getattr(self, name)) for name in self.__slots__)
+    return "<{} ({})>".format(type(self).__name__, fields)
+
+
+ICMPRequest.__repr__ = ICMPReply.__repr__ = better_repr  # so logged requests/replies list every slot
+
+
+def urping(address: str, hops: int, via: str, family=None, count=3, fudge=4,
+           id=None, interval=0.2, timeout=2, source=None, **kwargs) -> Hop:
+    """Ur-ping by (ab)using ICMP TTLs.
+
+    Craft an ICMP message which would go one (or more) hops FARTHER than the `address` host, routed towards `via`.
+    Send up to `count * fudge` packets with a TTL of `hops`, looking for responses from `address`.
+    Responses from `address` are considered; and a `Hop` is built from those results.
+    Other responses from other hosts are logged and discarded.
+
+    Args:
+        address: Host (name or IP) whose replies are recorded.
+        hops: TTL set on every probe; also passed to the `Hop` constructor.
+        via: Destination (name or IP) the probes are addressed to.
+        family: Address family handed to `resolve()` when a hostname is given.
+        count: Number of replies from `address` after which probing stops.
+        fudge: Multiplier on `count` bounding the total number of probes.
+        id: ICMP identifier for every probe; drawn afresh per probe when falsy.
+        interval: Seconds slept after each probe that did not end the loop.
+        timeout: Passed to `sock.receive()` as the per-reply timeout.
+        source: Passed to the socket constructor.
+        **kwargs: Forwarded to `ICMPRequest`.
+
+    Returns:
+        The `Hop` built for `address`; its RTT list (ms) and packet counter
+        only reflect replies that matched, not every probe sent.
+    """
+
+    if is_hostname(via):
+        via = resolve(via, family)[0]
+
+    if is_hostname(address):
+        address = resolve(address, family)[0]
+
+    _Socket = ICMPv6Socket if is_ipv6_address(via) else ICMPv4Socket
+    hop = Hop(address, 0, [], hops)
+
+    with _Socket(source) as sock:
+        for _ in range(count * fudge):
+            request = ICMPRequest(
+                destination=via,
+                # Unless the caller pins `id`, each probe gets a fresh identifier so that it
+                # looks like a new stream, to try and fool multipathing.
+                id=id or unique_identifier(),
+                sequence=0,
+                ttl=hops,
+                **kwargs)
+
+            try:
+                sock.send(request)
+                reply = sock.receive(request, timeout)
+                rtt = (reply.time - request.time) * 1000
+
+                # NOTE(review): the hop where a TTL-limited probe expires presumably answers
+                # with ICMP Time Exceeded; if raise_for_status() raises TimeExceeded for it,
+                # the handler below swallows it and nothing is recorded -- confirm.
+                reply.raise_for_status()
+
+                # Explicit comparison instead of `assert`, which `python -O` strips.
+                expected = (request.id, request.sequence, address)
+                if (reply.id, reply.sequence, reply.source) != expected:
+                    log.warning("Got response from unexpected node %s (expected %s) %r for request %r", reply.source, address, reply, request)
+                else:
+                    hop._packets_sent += 1
+                    hop._rtts.append(rtt)
+                    if hop._packets_sent >= count:
+                        break
+
+            except (TimeoutExceeded, TimeExceeded):
+                # No usable reply to this probe; move on to the next one.
+                pass
+
+            except ICMPLibError as e:
+                log.exception(e)
+                break
+
+            sleep(interval)
+
+    return hop
+
+
+def ping(address, count=4, interval=1, timeout=2, id=None, source=None,
+         family=None, privileged=True, sequence=0, **kwargs):
+    """A simple, if paranoid, ping.
+
+    Send `count` echo requests to `address`, `interval` seconds apart, using the
+    consecutive sequence numbers `(sequence + i) & 0xFFFF`. Only replies whose id,
+    sequence and source match the request contribute an RTT (in ms); mismatches are
+    logged and icmplib errors are swallowed. Returns `Host(address, sent, rtts)`.
+    """
+    if is_hostname(address):
+        address = resolve(address, family)[0]
+
+    _Socket = ICMPv6Socket if is_ipv6_address(address) else ICMPv4Socket
+
+    # `1<<16 - 1` parses as `1 << 15`; the parentheses open up the whole 16-bit id range.
+    id = id or randint(1, (1 << 16) - 1)
+    packets_sent, rtts = 0, []
+
+    with _Socket(source, privileged) as sock:
+        for base in range(count):
+            if base > 0:
+                sleep(interval)
+
+            request = ICMPRequest(
+                destination=address,
+                id=id,
+                # Offset from the caller's start value without reassigning it: start, start+1, ...
+                sequence=(sequence + base) & 0xFFFF,
+                **kwargs)
+
+            try:
+                sock.send(request)
+                packets_sent += 1
+
+                reply = sock.receive(request, timeout)
+                reply.raise_for_status()
+
+                # Explicit comparison instead of `assert`, which `python -O` strips.
+                got = (reply.id, reply.sequence, reply.source)
+                want = (request.id, request.sequence, address)
+                if got != want:
+                    log.warning("Got erroneous response:\n  request: %r\n  reply: %r\n  err: |\n%s", request, reply, indent(f"(id, sequence, source) was {got!r}, expected {want!r}", "    "))
+                else:
+                    rtts.append((reply.time - request.time) * 1000)
+
+            except ICMPLibError:
+                pass
+
+    return Host(address, packets_sent, rtts)