#!/usr/bin/env python3
"""Fetch NTP quality data from test hosts and generate comparison graphs."""

import subprocess, sys, os
import warnings
warnings.filterwarnings("ignore")

# (scenario label, instance ID) pairs for the hosts under test. The label names
# the local CSV file and the plot legend entry; the instance ID is used as the
# scp host name in fetch_data().
_HOST_TABLE = (
    ("E_periodic_patched", "i-0bbd4d706b525b2c6"),
    ("F_periodic_baseline", "i-051ba016b4648399a"),
    ("G_nohz_patched", "i-0a1209b28c3b15800"),
    ("H_nohz_baseline", "i-06ff421ea34e6dc60"),
)
HOSTS = dict(_HOST_TABLE)

# Local directory that receives the fetched CSVs and the generated PNGs.
DATA_DIR = os.path.expanduser("~/ntptest-virt/data")


def fetch_data():
    """Copy ``ntp_quality.csv`` from every host in HOSTS into DATA_DIR.

    Each host's file is fetched with ``scp`` (10 s connect timeout) and stored
    as ``<label>.csv``. A per-host status line is printed: the number of data
    rows (line count minus the header) on success, or ``FAILED`` plus scp's
    error output otherwise. A failed host does not abort the loop.

    After all hosts have been tried, the current UTC time is written to
    ``last_updated.txt`` in DATA_DIR, regardless of how many fetches succeeded.
    """
    from datetime import datetime, timezone

    os.makedirs(DATA_DIR, exist_ok=True)
    for label, inst in HOSTS.items():
        out = f"{DATA_DIR}/{label}.csv"
        # Flush so the progress text is visible while scp is still running.
        print(f"Fetching {label} ({inst})...", end=" ", flush=True)
        r = subprocess.run(
            ["scp", "-o", "ConnectTimeout=10",
             f"fedora@{inst}:ntp_quality.csv", out],
            capture_output=True,
        )
        if r.returncode == 0:
            with open(out) as f:
                # Subtract the header row; clamp so an empty file reports 0.
                lines = max(sum(1 for _ in f) - 1, 0)
            print(f"{lines} samples")
        else:
            # Surface scp's own diagnostics instead of discarding them.
            err = r.stderr.decode(errors="replace").strip()
            print(f"FAILED: {err}" if err else "FAILED")
    with open(f"{DATA_DIR}/last_updated.txt", "w") as f:
        f.write(datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC"))


# Line/box colour per host label; keys are expected to match HOSTS.
_PLOT_COLORS = {
    "E_periodic_patched": "blue",
    "F_periodic_baseline": "cyan",
    "G_nohz_patched": "red",
    "H_nohz_baseline": "orange",
}

# Averaging factors m (multiples of the sample interval) for the Allan plot.
_ADEV_FACTORS = (1, 2, 3, 5, 10, 20, 40, 80, 120)


def _save_figure(filename):
    """Tighten, save (150 dpi) and close the current pyplot figure in DATA_DIR."""
    import matplotlib.pyplot as plt

    plt.tight_layout()
    out_path = f"{DATA_DIR}/{filename}"
    plt.savefig(out_path, dpi=150)
    print(f"Saved: {out_path}")
    plt.close()


def _load_frames():
    """Read every existing host CSV and trim all of them to a common start.

    Returns a dict mapping host label to DataFrame, in HOSTS order. Hosts
    whose CSV is missing are skipped; the dict is empty if none exist.
    """
    import pandas as pd

    dfs = {}
    for label in HOSTS:
        path = f"{DATA_DIR}/{label}.csv"
        if os.path.exists(path):
            dfs[label] = pd.read_csv(path, parse_dates=["timestamp"])
    if not dfs:
        return dfs

    # Trim to when all hosts have data: start at the latest first-sample time.
    cutoff = max(df["timestamp"].min() for df in dfs.values())
    return {label: df[df["timestamp"] >= cutoff] for label, df in dfs.items()}


def _plot_timeseries(dfs):
    """Draw the five stacked time-series panels (ntp_quality_comparison.png)."""
    import matplotlib.dates as mdates
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(5, 1, figsize=(14, 14), sharex=True)
    for label, df in dfs.items():
        style = {
            "label": label.replace("_", " "),
            "color": _PLOT_COLORS[label],
            "alpha": 0.7,
            "linewidth": 0.8,
        }
        # Running RMS of the raw offset over all samples seen so far.
        cumulative_rms = (df["offset_ns"] ** 2).expanding().mean() ** 0.5
        panels = (
            df["offset_ns"],
            df["rms_offset_ns"],
            cumulative_rms,
            df["freq_ppm"],
            df["skew_ppm"],
        )
        for ax, series in zip(axes, panels):
            ax.plot(df["timestamp"], series, **style)

    axes[0].set_ylabel("Offset (ns)")
    axes[0].set_title("NTP Sync Quality — NTP Precision Accounting Test")
    # symlog keeps sign while compressing large excursions; linear within ±10 µs.
    axes[0].set_yscale("symlog", linthresh=10000)
    axes[1].set_ylabel("RMS Offset (ns)\n(chrony)")
    axes[2].set_ylabel("RMS Offset (ns)\n(cumulative)")
    axes[3].set_ylabel("Freq correction (PPM)")
    axes[4].set_ylabel("Skew (PPM)")
    axes[4].set_xlabel("Time (UTC)")
    for ax in axes:
        ax.legend(loc="upper left", fontsize=7)
        ax.grid(True, alpha=0.3)
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%m-%d %H:%M"))
    _save_figure("ntp_quality_comparison.png")


def _plot_boxplots(dfs):
    """Draw per-host box plots of offset and skew (ntp_quality_boxplot.png)."""
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(1, 2, figsize=(12, 5))
    labels = list(dfs)
    short_labels = [name.replace("_", "\n") for name in labels]
    columns = (("offset_ns", "Offset (ns)"), ("skew_ppm", "Skew (PPM)"))

    for ax, (col, title) in zip(axes, columns):
        data = [dfs[name][col].dropna().values for name in labels]
        bp = ax.boxplot(data, patch_artist=True)
        # Label the ticks separately: boxplot's ``labels=`` keyword was renamed
        # to ``tick_labels`` in matplotlib 3.9, so neither spelling is portable.
        ax.set_xticklabels(short_labels)
        for patch, name in zip(bp["boxes"], labels):
            patch.set_facecolor(_PLOT_COLORS[name])
            patch.set_alpha(0.5)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
    plt.suptitle("NTP Quality Distribution Comparison", fontsize=13)
    _save_figure("ntp_quality_boxplot.png")


def _plot_cdf(dfs):
    """Draw the empirical CDF of |offset| per host (ntp_quality_cdf.png)."""
    import matplotlib.pyplot as plt
    import pandas as pd

    _, ax = plt.subplots(figsize=(10, 6))
    for label, df in dfs.items():
        # Drop NaNs first so they do not inflate the CDF denominator
        # (consistent with the box plots, which also ignore missing samples).
        abs_offset = df["offset_ns"].dropna().abs().sort_values()
        cdf = pd.Series(range(1, len(abs_offset) + 1)) / len(abs_offset)
        ax.plot(abs_offset.values, cdf.values, label=label.replace("_", " "),
                color=_PLOT_COLORS[label], linewidth=1.5)
    ax.set_xlabel("|Offset| (ns)")
    ax.set_ylabel("CDF")
    ax.set_title("CDF of Absolute NTP Offset")
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(left=0)
    _save_figure("ntp_quality_cdf.png")


def _overlapping_adev(x, tau0, factors=_ADEV_FACTORS):
    """Return ``(taus, adevs)`` for phase samples *x* spaced *tau0* seconds apart.

    For each averaging factor m the second difference
    ``x[i+2m] - 2*x[i+m] + x[i]`` is taken over every valid i (overlapping
    estimator) and the deviation is ``sqrt(mean(diff**2) / (2 * (m*tau0)**2))``.
    Factors are assumed ascending; evaluation stops at the first m for which
    the series is too short (``2*m >= len(x)``).
    """
    n = len(x)
    taus = []
    adevs = []
    for m in factors:
        if 2 * m >= n:
            break
        tau = m * tau0
        diffs = x[2 * m:] - 2 * x[m:-m] + x[:n - 2 * m]
        avar = (diffs ** 2).mean() / (2 * tau ** 2)
        taus.append(tau)
        adevs.append(avar ** 0.5)
    return taus, adevs


def _plot_adev(dfs, tau0=30):
    """Draw overlapping Allan deviation per host (ntp_quality_adev.png).

    *tau0* is the sample interval in seconds; the computation assumes the CSV
    rows are evenly spaced at that interval.
    """
    import matplotlib.pyplot as plt

    _, ax = plt.subplots(figsize=(10, 6))
    for label, df in dfs.items():
        phase = df["offset_ns"].values / 1e9  # convert to seconds
        taus, adevs = _overlapping_adev(phase, tau0)
        if taus:
            ax.loglog(taus, adevs, 'o-', label=label.replace("_", " "),
                      color=_PLOT_COLORS[label], linewidth=1.5, markersize=4)
    ax.set_xlabel("Averaging time τ (seconds)")
    # Phase differences (s) divided by tau (s): the deviation has no unit.
    ax.set_ylabel("Allan Deviation (dimensionless)")
    ax.set_title("Allan Deviation — Clock Stability vs Averaging Time")
    ax.legend()
    ax.grid(True, alpha=0.3, which="both")
    _save_figure("ntp_quality_adev.png")


def plot():
    """Render all comparison graphs from the CSVs in DATA_DIR.

    Loads ``<label>.csv`` for each host in HOSTS (missing files are skipped),
    trims every series to start at the latest first-sample timestamp, and
    writes four PNGs into DATA_DIR: the stacked time series, the box plots,
    the CDF of absolute offset, and the Allan deviation. If no CSV is present
    a message is printed and nothing is drawn.
    """
    dfs = _load_frames()
    if not dfs:
        print(f"No data files found in {DATA_DIR}; nothing to plot.")
        return

    _plot_timeseries(dfs)
    _plot_boxplots(dfs)
    _plot_cdf(dfs)
    _plot_adev(dfs)


if __name__ == "__main__":
    # --plot-only skips the download step and redraws from the CSVs on disk.
    plot_only = "--plot-only" in sys.argv
    if not plot_only:
        fetch_data()
    plot()
