# BNS Results

This notebook includes the code that produces the figures and tables presented in
Section 5.2, which covers the BNS analyses.

## Imports

Import the various modules we're going to use.

In [None]:
import json
import re
from collections import defaultdict
from itertools import zip_longest
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tabulate
from IPython.display import HTML, display
from matplotlib.lines import Line2D
from pesummary.gw.plots.latex_labels import GWlatex_labels

from gw_smc_utils.plotting import lighten_colour, set_style

# Apply the plotting style provided by gw_smc_utils
set_style()

# There are some issues with plotting when using latex
plt.rcParams.update({"text.usetex": False})

# Create the output directories used when saving figures and tables
for output_dir in ("figures", "tables"):
    Path(output_dir).mkdir(exist_ok=True)

## Data release paths

We specify paths to the data release and relevant files.

Change these if you have downloaded the data release to a different path.

In [None]:
# Root of the downloaded data release; edit this if it lives elsewhere
data_release_path = Path("../data_release/gw_smc_data_release_core")
# Directory holding the simulated BNS results
bns_results = data_release_path.joinpath("simulated_data", "bns_results")
# HDF5 file with the per-run summary statistics
summary_path = bns_results.joinpath("bns_results_summary.hdf5")

## Jensen-Shannon Divergence

Load in the JSD results from the JSON file.

Each JSON file contains the following keys:

```
{
    "res1": "path to result file 1",
    "res2": "path to result file 2",
    "base": 2,            # Base used in the calculation
    "seed": 1234,         # The random seed
    "n_samples": 5000,    # The number of samples
    "n_tests": 10,        # The number of tests for bootstrapping errors
    "jsd": {
        "chirp_mass": [
            ...
        ],
        "mass_ratio": [
            ...
        ],
        ...
    }
}
```

In [None]:
jsd_path = bns_results / "jsd_results"

# Fail early if the JSD directory holds no JSON files at all
jsd_files = [*jsd_path.glob("*.json")]
if len(jsd_files) == 0:
    raise FileNotFoundError("No JSD files found in the specified directory.")

dets = ["2det", "3det"]
labels = [
    "aligned_with_tides",
    "aligned_without_tides",
    "precessing_with_tides",
    "precessing_without_tides",
]

# jsd_data[det][label] starts as an empty dict and is replaced by the parsed
# JSON of every file whose name contains both `det` and `label`; if several
# files match, the last one in `jsd_files` is the one that is kept.
jsd_data = {det: {label: {} for label in labels} for det in dets}
for det, per_label in jsd_data.items():
    for label in labels:
        matches = [p for p in jsd_files if det in p.name and label in p.name]
        for match in matches:
            with open(match, "r") as handle:
                per_label[label] = json.load(handle)

We collate the results into a `pandas` `DataFrame` to make handling the results easier.

We take the mean of the JSD values per-parameter.

In [None]:
# Build one record per (detector configuration, analysis label) and create the
# DataFrame in a single call. Concatenating onto an initially empty frame inside
# the loop is quadratic and triggers pandas' empty-entry concat FutureWarning.
jsd_records = []
for det in dets:
    for label in labels:
        # The identifying columns come first so that the parameter columns
        # start at position 3, which the plotting code relies on.
        record = {
            "ndet": det[0],  # first character of e.g. "2det" -> "2"
            "tides": "without" not in label,
            "spin": "aligned" if "aligned" in label else "precessing",
        }
        # Reduce the JSD values of each parameter to their NaN-ignoring mean
        for key, value in jsd_data[det][label]["jsd"].items():
            record[key] = float(np.nanmean(value))
        jsd_records.append(record)

jsd_df = pd.DataFrame(jsd_records)

## Figures 9 & B4

We make both figures using the same code and a for loop.

In [None]:
# Maps parameter name -> axis label (filled in while plotting)
latex_labels = {}

# Labels for parameters that are looked up here when GWlatex_labels has no entry
other_gw_labels = {
    "chi_1": "$\\chi_1$",
    "chi_2": "$\\chi_2$",
}

include_tides = [True, False]

# Marker and colour for each detector configuration
detector_styles = {
    2: {"marker": "^", "colour": lighten_colour("k", 0.5)},
    3: {"marker": "o", "colour": "k"},
}

# One figure per tides setting; each has an aligned and a precessing panel
for tides in include_tides:
    figsize = plt.rcParams["figure.figsize"].copy()
    figsize[1] = 2 * figsize[1]
    fig, axs = plt.subplots(2, 1, figsize=figsize, sharex=True, height_ratios=[1, 1.4])

    for ax, spin in zip(axs, ["aligned", "precessing"]):
        subset = jsd_df[(jsd_df["spin"] == spin) & (jsd_df["tides"] == tides)]
        # The first three columns are ndet/tides/spin; the rest are parameters.
        all_parameters = subset.columns[3:]
        # Drop parameters that are NaN for every row of this subset
        parameters = all_parameters[~subset[all_parameters].isna().all()]
        for p in parameters:
            raw_label = GWlatex_labels.get(p, other_gw_labels.get(p, p))
            # Strip any bracketed suffix (e.g. units) from the label
            latex_labels[p] = re.sub(r"\[.*?\]", "", raw_label)
        yticks = np.arange(len(parameters))[::-1]
        yoffset = 0.2

        for j, ndet in enumerate([2, 3]):
            det_subset = subset[subset["ndet"] == str(ndet)]
            style = detector_styles[ndet]
            # Shift the two detector configurations either side of the tick
            offset = yoffset * (j - 0.5)
            for i, p in enumerate(parameters):
                # Clip to [0, 1] and convert to mbits
                jsd_mbits = 1000 * np.clip(np.array(det_subset[p].values[0]), 0, 1)
                # NOTE: each jsd_df entry is a scalar (the per-parameter mean),
                # so std() evaluates to 0 for these values.
                ax.errorbar(
                    jsd_mbits.mean(),
                    yticks[i] + offset,
                    xerr=jsd_mbits.std(),
                    fmt=style["marker"],
                    label=latex_labels[p],
                    c=style["colour"],
                )

        # Axis formatting does not depend on the detector configuration, so it
        # is applied once per panel after all points have been drawn.
        ax.set_xscale("log")
        ax.set_yticks(yticks)
        # Disable minor ticks
        ax.set_yticks([], minor=True)
        ax.set_yticklabels([latex_labels[p] for p in parameters])

    axs[0].set_title("Aligned spin")
    axs[1].set_title("Precessing spin")

    for ax in axs:
        ax.axvline(2, color="k", ls="--")
        ax.set_xlim(0.1, 30)

    # Proxy artists for the legend, listed as 3 detectors then 2 detectors
    legend_handles = [
        Line2D(
            [0],
            [0],
            marker=style["marker"],
            color="w",
            markerfacecolor=style["colour"],
            markersize=5,
            label=str(ndet),
            ls="",
        )
        for ndet, style in sorted(detector_styles.items(), reverse=True)
    ]

    axs[1].legend(
        handles=legend_handles,
        loc="lower right",
        fontsize="small",
        title="# detectors",
        title_fontsize="small",
    )

    axs[-1].set_xlabel(r"$D_{\rm JS}$ [mbits]")
    filename = f"bns_jsd_{'with' if tides else 'without'}_tides.pdf"
    fig.savefig(f"figures/{filename}", bbox_inches="tight")
    plt.show()
    # Release the figure once it has been saved and shown
    plt.close(fig)

## Run statistics - Tables 1 & B1

Load the data from the summary file.

In [None]:
# Copy the HDF5 summary into nested dictionaries:
# data[sampler][ndetector][key][stat] -> contents of the dataset read via [:]
data = {}
with h5py.File(summary_path, "r") as summary:
    for sampler, sampler_group in summary.items():
        data[sampler] = {
            ndetector: {
                key: {stat: dataset[:] for stat, dataset in key_group.items()}
                for key, key_group in det_group.items()
            }
            for ndetector, det_group in sampler_group.items()
        }

### Formatting the data

Reformat the data into rows for the table.

This includes taking the mean of the run statistics.

In [None]:
rows = []

# One table row per (sampler, detector configuration, analysis key)
for sampler, sampler_data in data.items():
    for det, det_data in sampler_data.items():
        for key, key_data in det_data.items():
            key_lower = key.lower()
            n_likelihood = key_data["likelihood_evaluations"]
            # Sampling time divided by 60, as in the wall-time table column
            time_over_60 = key_data["sampling_time"] / 60
            n_samples = key_data["n_samples"]
            record = {
                "sampler": sampler,
                "det": det,
                "key": key,
                # Analysis settings are inferred from the key name
                "tides": "with_tides" in key_lower,
                "precessing": "precessing" in key_lower,
                # Run statistics are averaged over the entries of each dataset
                "likelihood_evaluations": float(np.mean(n_likelihood)),
                "sampling_time": float(np.mean(time_over_60)),
                "likelihood_evaluations_per_sample": float(
                    np.mean(n_likelihood / n_samples)
                ),
                "sampling_time_per_sample": float(np.mean(time_over_60 / n_samples)),
            }
            rows.append(record)


df = pd.DataFrame(rows)

# Define group order
# Order: (precessing, tides)
group_order = [(False, False), (False, True), (True, False), (True, True)]


# Helper to interleave rows
def interleave_group(group_df):
    """Interleave the rows of each sampler in ``group_df``.

    The frame is grouped by the ``"sampler"`` column, each group is sorted by
    ``"det"`` and converted to a list of row dictionaries. The lists are then
    merged round-robin: the first row of every sampler, then the second row of
    every sampler, and so on. Samplers that run out of rows are skipped.

    Parameters
    ----------
    group_df : pandas.DataFrame
        Frame with at least the columns ``"sampler"`` and ``"det"``.

    Returns
    -------
    list of dict
        The interleaved rows, one dictionary per row.
    """
    per_sampler = []
    for _, sampler_rows in group_df.groupby("sampler"):
        ordered = sampler_rows.sort_values("det")
        per_sampler.append(ordered.to_dict(orient="records"))

    merged = []
    for tier in zip_longest(*per_sampler):
        for record in tier:
            # zip_longest pads the shorter lists with None
            if record is not None:
                merged.append(record)
    return merged


# Process all groups and interleave
final_rows = []
for precessing, tides in group_order:
    # Select the rows that belong to this (precessing, tides) combination
    in_group = (df["precessing"] == precessing) & (df["tides"] == tides)
    group_df = df[in_group]
    final_rows += interleave_group(group_df)

### HTML Table

Below, the table is visualized in HTML using the `tabulate` package.

In [None]:
# Render the rows as an HTML table, using the dictionary keys as column headers
html_table = tabulate.tabulate(
    final_rows,
    headers="keys",
    tablefmt="html",
)
display(HTML(html_table))

### Latex tables

We then generate the latex tables for the paper.

These are written to files in the `tables` directory.

In [None]:
table_data = defaultdict(
    lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
)

# Row entries written to each LaTeX table, in column order
table_keys = [
    "sampler",
    "precessing",
    "tides",
    "likelihood_evaluations",
    "sampling_time",
    "likelihood_evaluations_per_sample",
]


def format_sci_latex(x, precision=2):
    """Format number in LaTeX-style scientific notation.

    Parameters
    ----------
    x : float
        The value to format.
    precision : int
        Number of decimal places shown for the mantissa.

    Returns
    -------
    str
        Math-mode string of the form ``$m \\times 10^{e}$``, or a plain
        zero with ``precision`` decimals when ``x`` is zero.
    """
    if x == 0:
        return f"${0:.{precision}f}$"
    exponent = int(np.floor(np.log10(abs(x))))
    mantissa = round(x / (10**exponent), precision)
    # Rounding can push the mantissa up to 10 (e.g. 9.96 -> 10.0 with one
    # decimal); renormalise so the mantissa stays below 10.
    if abs(mantissa) >= 10:
        mantissa /= 10
        exponent += 1
    return f"${mantissa:.{precision}f} \\times 10^{{{exponent}}}$"


# Function to format a row (can average if needed)
def format_cell(entry, keys):
    """Convert the values of ``entry`` selected by ``keys`` to LaTeX strings.

    Parameters
    ----------
    entry : dict
        A row dictionary; a falsy entry yields a row of ``"--"`` placeholders.
    keys : list of str
        The keys to format, in output order.

    Returns
    -------
    list of str
        One formatted cell per key.
    """
    if not entry:
        return ["--"] * len(keys)
    out = []
    for key in keys:
        v = entry[key]
        if key == "likelihood_evaluations":
            out.append(format_sci_latex(v, 1))
        elif key == "sampler":
            out.append(r"\texttt{" + v + "}")
        elif isinstance(v, (bool, np.bool_)):
            # numpy booleans are handled too so flags never print as True/False
            out.append(r"\cmark" if v else r"\xmark")
        elif isinstance(v, float):
            out.append(f"{v:.4g}")
        else:
            out.append(str(v))
    return out


# Build LaTeX table
header = r"""
    \begin{tabular}{lcccccc}
    \toprule
    Sampler & Precession & Tides & Likelihood evaluations & Wall time [min] & Likelihood evaluations per sample \\
    \midrule
    """

footer = r"""
    \bottomrule
    \end{tabular}
    """

for ndet in [2, 3]:
    # Keep only the rows for this detector configuration. The previous filter
    # skipped exactly these rows, so each file received the other
    # configuration's results.
    # NOTE(review): assumes the detector keys in the summary file are
    # "2det"/"3det" - confirm against the HDF5 file.
    det_rows = [row for row in final_rows if row["det"] == f"{ndet}det"]

    latex_rows = []
    for r in det_rows:
        cells = format_cell(r, keys=table_keys)
        line = " & ".join(cells) + r" \\"
        if r["sampler"] == "dynesty":
            # Rows for the dynesty sampler are shaded
            line = r"\rowcolor{lightgrey}" + line
        latex_rows.append(line)

    latex_table = header + "\n".join(latex_rows) + footer

    with open(
        f"tables/bns_sampling_time_vs_likelihood_evaluations_{ndet}det.tex", "w"
    ) as f:
        f.write(latex_table)