import os
from glob import glob

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.lines as mlines


# Reference Dice scores of the nnUNet baseline, keyed by dataset id. Lists with several
# entries are reduced with `np.mean` before plotting (presumably one entry per
# class/label — TODO confirm).
NNUNET_RESULTS = {
    # 2d
    "oimhs": [0.9899, 0.8763,  0.8537, 0.9966],
    "isic": [0.8404],
    "dca1": [0.8003],
    "cbis_ddsm": [0.4329],
    "piccolo": [0.6749],
    "hil_toothseg": [0.8921],

    # 3d
    "osic_pulmofib": [0.4984, 0.8858, 0.7850],
    "leg_3d_us": [0.8943, 0.9059, 0.8865],
    "oasis": [0.9519, 0.9689, 0.9773, 0.9656],
    "micro_usp": [0.8402],
    "lgg_mri": [0.8875],
    "duke_liver": [0.9117],
}

# Reference Dice scores of the SwinUNETR baseline; same keys and layout as NNUNET_RESULTS.
SWINUNETR_RESULTS = {
    # 2d
    "oimhs": [0.8146, 0.5489, 0.5819, 0.9683],
    "isic": [0.8479],
    "dca1": [0.6718],
    "cbis_ddsm": [0.2377],
    "piccolo": [0.4384],
    "hil_toothseg": [0.6901],

    # 3d
    "osic_pulmofib": [0.4844, 0.8954, 0.7638],
    "leg_3d_us": [0.7266, 0.7004, 0.6299],
    "oasis": [0.8901, 0.8157, 0.9131, 0.6577],
    "micro_usp": [0.8498],
    "lgg_mri": [0.7369],
    "duke_liver": [0.8666],
}

# Reference Dice scores of BiomedParse; only available for a subset of the 2D datasets.
BIOMEDPARSE_RESULTS = {
    # 2d
    "oimhs": [0.7079],  # numbers for intraretinal cysts only.
    "isic": [0.8851],
    "piccolo": [0.8459],
}

# Human-readable dataset names (with imaging modality), used as subplot titles.
DATASET_MAPS = {
    # 2d
    "oimhs": "OIMHS (OCT)",
    "isic": "ISIC (Dermoscopy)",
    "dca1": "DCA1 (X-Ray Coronary Angiograms)",
    "cbis_ddsm": "CBIS DDSM (Mammography)",
    "piccolo": "PICCOLO (Narrow Band Imaging)",
    "hil_toothseg": "HIL ToothSeg (Dental Radiographs)",

    # 3d
    "osic_pulmofib": "OSIC PulmoFib (CT)",
    "leg_3d_us": "LEG 3D US (Ultrasound)",
    "oasis": "OASIS (MRI)",
    "micro_usp": "MicroUSP (Micro-Ultrasound)",
    "lgg_mri": "LGG MRI (Brain MRI)",
    "duke_liver": "DLDS (MRI)",
}

# Dataset ids grouped by dimensionality; used to select datasets for the 2D / 3D averages.
DATASETS_2D = ["oimhs", "isic", "dca1", "cbis_ddsm", "piccolo", "hil_toothseg"]

DATASETS_3D = ["osic_pulmofib", "leg_3d_us", "oasis", "micro_usp", "lgg_mri", "duke_liver"]

# Tick labels for the plotted methods. Keys are the two baselines plus the combination names
# built by `get_results` ("<finetune-prefix>/<model>[/<setting>]").
# NOTE(review): the "half" and "full" MedicoSAM variants both map to "MedicoSAM*", so their
# bars cannot be told apart by tick label — confirm this is intended.
MODEL_MAPS = {
    "nnunet": "nnUNet",
    "swinunetr": "SwinUNETR",
    "full/sam": "SAM",
    "full/medico-samv2-half/wo_decoder": "MedicoSAM*",
    "full/medico-samv2-full/wo_decoder": "MedicoSAM*",
    "full/medico-samv2-full/w_decoder": r"MedicoSAM*$_{\mathrm{Dec}}$",
    "full/medsam": "MedSAM",
    "full/simplesam": "Simple FT*",
}

# Cluster-specific base directory that `get_results` searches for result CSV files.
ROOT = "/mnt/vast-nhr/projects/cidas/cca/models/semantic_sam/v2"


def get_results(dataset_name, root=None):
    """Collect the Dice scores of all fine-tuned model variants for one dataset.

    Result CSVs are looked up under
    ``<root>/<finetune>/<model>/inference/<dataset_name>/<setting>/results/<subdir>/<file>.csv``
    and each file contributes the ``dice`` value of its first row. Scores are grouped under a
    combination name ``"<finetune-prefix>/<model>"`` (e.g. ``"full/sam"``), where the prefix is
    the part of ``<finetune>`` before the first ``"_"``. For models whose name starts with
    ``"medico-sam"`` the ``<setting>`` directory is appended as well
    (e.g. ``"full/medico-samv2-full/w_decoder"``).

    Args:
        dataset_name: Name of the dataset directory below ``inference``.
        root: Base directory to search. Defaults to the module-level ``ROOT``.

    Returns:
        A DataFrame with one row per combination, in order of first appearance over the sorted
        file list, and the columns ``"name"`` (str) and ``"dice"`` (list of scores). When no
        result files are found the frame is empty but still has both columns.
    """
    if root is None:
        root = ROOT

    # NOTE: glob is called without `recursive=True`, so `**` matches exactly one directory
    # level. The positional `psplits[-4]` lookup below relies on that fixed depth.
    pattern = os.path.join(root, "*", "*", "inference", dataset_name, "*", "results", "**", "*.csv")

    grouped = {}  # combination name -> list of scores, kept in order of first appearance
    for rpath in sorted(glob(pattern)):
        # relpath (instead of slicing off len(root) + 1 chars) also copes with a trailing
        # separator on `root`.
        psplits = os.path.relpath(rpath, root).split(os.sep)
        ft_name, mname = psplits[0].split("_")[0], psplits[1]

        if mname == "medico-sam-1g":  # HACK: we do not get results for medico-sam trained on 1 GPU.
            continue

        score = pd.read_csv(rpath).iloc[0]["dice"]

        combination_name = f"{ft_name}/{mname}"
        if mname.startswith("medico-sam"):
            combination_name += f"/{psplits[-4]}"

        grouped.setdefault(combination_name, []).append(score)

    # Building the frame once also avoids `pd.concat([])` raising for datasets without results.
    return pd.DataFrame(
        [{"name": name, "dice": scores} for name, scores in grouped.items()],
        columns=["name", "dice"],
    )


def _bar_colors_by_rank(scores, top_colors, default_color):
    """Return one bar color per entry of ``scores``.

    The ``len(top_colors)`` highest scores receive ``top_colors[0]``, ``top_colors[1]``, ...
    in descending order of score; every other bar receives ``default_color``.
    """
    colors = [default_color] * len(scores)
    ranking = np.argsort(scores)[::-1]  # indices from best to worst score
    for rank, idx in enumerate(ranking[:len(top_colors)]):
        colors[idx] = top_colors[rank]
    return colors


def _make_per_dataset_plot():
    """Render Figure 4: one bar chart per dataset on a 4x3 grid.

    For every dataset in ``NNUNET_RESULTS``, the mean Dice of each SAM-based method returned by
    ``get_results`` is drawn as a bar. The three best bars are colored in shades of blue and the
    rest grey. Horizontal reference lines mark nnUNet, SwinUNETR and, where available,
    BiomedParse. The figure is saved to ``./fig_4_semantic_segmentation_per_dataset.png`` and
    ``./fig_4_semantic_segmentation_per_dataset.svg``.
    """
    results = {}
    for dataset, nnunet_scores in NNUNET_RESULTS.items():
        # Look the SwinUNETR baseline up by key rather than zipping both dicts. That way the
        # pairing does not silently depend on the two dicts sharing the same insertion order.
        results[dataset] = {
            "nnunet": np.mean(nnunet_scores),
            "swinunetr": np.mean(SWINUNETR_RESULTS[dataset]),
        }
        if dataset in BIOMEDPARSE_RESULTS:
            results[dataset]["biomedparse"] = np.mean(BIOMEDPARSE_RESULTS[dataset])

        scores = get_results(dataset)
        for name, dice in zip(scores["name"], scores["dice"]):
            results[dataset][name] = np.mean(dice)

    # Bars are drawn in this order; methods without results for a dataset are left out.
    methods_list = [
        "full/sam",
        "full/medsam",
        "full/simplesam",
        "full/medico-samv2-half/wo_decoder",
        "full/medico-samv2-full/wo_decoder",
        "full/medico-samv2-full/w_decoder",
    ]
    top_colors = ["#045275", "#2B6C8F", "#5093A9"]
    default_color = "#D3D3D3"

    fig, axes = plt.subplots(4, 3, figsize=(40, 30))
    axes = axes.flatten()

    for ax, (dataset, methods) in zip(axes, results.items()):
        present = [m for m in methods_list if m in methods]
        scores = [methods[m] for m in present]

        ax.bar(
            present, scores,
            color=_bar_colors_by_rank(scores, top_colors, default_color),
            edgecolor="none", linewidth=1.5,
        )

        # Baseline reference lines.
        ax.axhline(methods["nnunet"], color="#DC3977", linewidth=4)
        ax.axhline(methods["swinunetr"], color="#7CCBA2", linewidth=4)
        if dataset in BIOMEDPARSE_RESULTS:
            # Dashed for OIMHS, presumably because its BiomedParse score covers intraretinal
            # cysts only (see the comment in BIOMEDPARSE_RESULTS).
            kwargs = {"linestyle": "--"} if dataset == "oimhs" else {}
            ax.axhline(methods["biomedparse"], color="#C99833", linewidth=4, **kwargs)

        ax.set_ylim([0.2, 1])
        ax.set_xticks(np.arange(len(present)))
        ax.set_xticklabels([MODEL_MAPS[m] for m in present], rotation=45, fontsize=26)
        ax.tick_params(axis="y", labelsize=20)

        for label, method in zip(ax.get_xticklabels(), present):
            if "medico-samv2-full" in method:
                label.set_fontweight("bold")

        # 2D datasets get italic titles, all others bold ones.
        fontdict = {"fontsize": 30}
        if dataset in DATASETS_2D:
            fontdict["fontstyle"] = "italic"
        else:
            fontdict["fontweight"] = "bold"

        ax.set_title(DATASET_MAPS[dataset], fontdict=fontdict)
        ax.title.set_color("#212427")

    nnunet_line = mlines.Line2D([], [], color="#DC3977", linewidth=8, label="nnUNet")
    swinunetr_line = mlines.Line2D([], [], color="#7CCBA2", linewidth=8, label="SwinUNETR")
    biomed_line = mlines.Line2D([], [], color="#C99833", linewidth=8, label="BiomedParse")

    fig.legend(
        handles=[nnunet_line, swinunetr_line, biomed_line],
        loc="lower center", bbox_to_anchor=(0.5, 0.0175), ncol=3, fontsize=32,
    )

    # Shared y-axis label. plt.text draws in data coordinates of the current axes (presumably
    # the last subplot created above), hence the hand-tuned x / y values.
    plt.text(
        x=-15.5, y=2.1, s="Dice Similarity Coefficient", rotation=90, fontweight="bold", fontsize=32,
    )

    plt.subplots_adjust(hspace=0.875, wspace=0.1, bottom=0.125)
    plt.savefig("./fig_4_semantic_segmentation_per_dataset.png", bbox_inches="tight")
    plt.savefig("./fig_4_semantic_segmentation_per_dataset.svg", bbox_inches="tight")
    plt.close()


def _mean_over_datasets(results):
    """Average each method's score over the datasets it has a score for.

    Args:
        results: Mapping ``dataset -> {method: score}``.

    Returns:
        Dict ``method -> mean score``. A method that is missing from some datasets is averaged
        only over the datasets where it appears.
    """
    collected = {}
    for per_method in results.values():
        for method, score in per_method.items():
            collected.setdefault(method, []).append(score)
    return {method: sum(values) / len(values) for method, values in collected.items()}


def _plot_absolute_mean_per_experiment(dim):
    """Render Figure 1b: the Dice of each method, averaged over datasets.

    Args:
        dim: ``"2d"`` restricts the average to ``DATASETS_2D`` and ``"3d"`` to ``DATASETS_3D``.
            Any other value averages over all datasets in ``NNUNET_RESULTS``. The value also
            appears (upper-cased) in the title and in the output file names.

    The figure is saved to ``./fig_1b_semantic_segmentation_{dim}_average.png`` and ``.svg``.
    Methods without a score on any selected dataset are omitted from the plot.
    """
    methods = [
        "nnunet",
        "swinunetr",
        "full/sam",
        "full/medsam",
        "full/simplesam",
        "full/medico-samv2-half/wo_decoder",
        "full/medico-samv2-full/wo_decoder",
        "full/medico-samv2-full/w_decoder",
    ]

    results = {}
    for dataset, nnunet_scores in NNUNET_RESULTS.items():
        if dim == "3d" and dataset not in DATASETS_3D:
            continue

        if dim == "2d" and dataset not in DATASETS_2D:
            continue

        # Key lookup instead of zipping both dicts keeps the two baselines correctly paired.
        results[dataset] = {
            "nnunet": np.mean(nnunet_scores),
            "swinunetr": np.mean(SWINUNETR_RESULTS[dataset]),
        }
        scores = get_results(dataset)
        for name, dice in zip(scores["name"], scores["dice"]):
            results[dataset][name] = np.mean(dice)

    method_avgs = _mean_over_datasets(results)

    # Drop methods without any score. Otherwise they would raise a KeyError below, and
    # `method_avgs.get` would return None as a sort key and break the ranking.
    methods = [m for m in methods if m in method_avgs]

    top_colors = ["#045275", "#2B6C8F", "#5093A9"]
    top_methods = sorted(methods, key=method_avgs.get, reverse=True)[:len(top_colors)]
    means = [method_avgs[m] for m in methods]

    _, ax = plt.subplots(figsize=(28, 15))
    bars = ax.bar(
        methods, means,
        edgecolor=["None" if m in top_methods else "grey" for m in methods],
        linewidth=1.5,
        color=[top_colors[top_methods.index(m)] if m in top_methods else "#D3D3D3" for m in methods],
    )

    ax.set_ylim([0, 1])
    ax.set_xticks(np.arange(len(methods)))
    ax.set_xticklabels([MODEL_MAPS[m] for m in methods], rotation=15, fontsize=32)
    ax.tick_params(axis="y", labelsize=30)
    ax.set_ylabel("Dice Similarity Coefficient", fontsize=36, fontweight="bold")

    # Annotate every bar with its (rounded) value; the top methods are highlighted.
    for bar, mean, method in zip(bars, means, methods):
        is_top = method in top_methods
        ax.text(
            bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, round(mean, 4),
            ha="center", va="bottom", fontsize=32,
            color="black" if is_top else "#696969",
            fontweight="bold" if is_top else "normal",
        )

    for label, method in zip(ax.get_xticklabels(), methods):
        if "medico-samv2-full" in method:
            label.set_fontweight("bold")

    plt.title(f"Semantic Segmentation {dim.upper()}", fontsize=40, fontweight="bold")
    plt.savefig(f"./fig_1b_semantic_segmentation_{dim}_average.png", bbox_inches="tight")
    plt.savefig(f"./fig_1b_semantic_segmentation_{dim}_average.svg", bbox_inches="tight")
    plt.close()


def main():
    """Entry point: render the figures of the semantic segmentation evaluation."""
    # Figure 4: per-dataset bar charts.
    _make_per_dataset_plot()

    # Figure 1 (b): per-dimension averages, currently disabled. Uncomment to regenerate:
    # for dim in ("2d", "3d"):
    #     _plot_absolute_mean_per_experiment(dim=dim)


if __name__ == "__main__":
    main()
