from netCDF4 import Dataset
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.mpl.ticker as cticker
import cartopy.feature as cfeature
import subprocess

# -----------------
# Open the NetCDF inputs in Dataset's default mode.
# Per their file names: nclim holds the 1991-2020 climatology and ncurr the
# 2025-09 forecast; both are ensemble means (ENSMEAN).
# -----------------
nclim, ncurr = (
    Dataset("CFSv2.prate.ENSMEAN.fcst.1991-2020_09.clim.nc"),
    Dataset("CFSv2.prate.202509.ENSMEAN.fcst.varb.nc"),
)

# -----------------
# Lon/Lat grid (global, 1-degree spacing)
# -----------------
# Longitudes run 0 -> 360 inclusive (361 points; 360 coincides with 0).
Lon_1D = np.linspace(0, 360, 361)
# Latitudes run north to south, 90 -> -90 (181 points).
Lat_1D = np.linspace(90, -90, 181)
# 2-D coordinate arrays with shape (len(Lat_1D), len(Lon_1D)) == (181, 361):
# rows follow latitude, columns follow longitude.
Lon, Lat = np.meshgrid(Lon_1D, Lat_1D)

# -----------------
# Plotting subdomain: 70E-150E, 15N-60N (East Asia; the plot loop writes
# files tagged "_easia").
# -----------------
lon_mask = (Lon_1D >= 70) & (Lon_1D <= 150)
lat_mask = (Lat_1D >= 15) & (Lat_1D <= 60)

# np.ix_ turns the two 1-D boolean masks into an open mesh, so these pick
# the rows (latitudes) and columns (longitudes) that fall inside the box.
Lon_sub = Lon[np.ix_(lat_mask, lon_mask)]
Lat_sub = Lat[np.ix_(lat_mask, lon_mask)]

# -----------------
# Contour levels (anomaly, mm/day) and fill colors
# -----------------
# The boundaries are symmetric about zero: every positive magnitude also
# appears negated. Sorting the signed values yields ascending order.
levels = sorted(sign * mag
                for mag in (0.25, 0.5, 1, 2, 4, 6, 10, 15)
                for sign in (-1, 1))

# One fill color per band between consecutive levels, lowest band first:
# browns/reds/oranges/yellows below zero, white for the near-zero band,
# greens then blues above zero.
colors = [
    (0.47, 0.31, 0.25),  # -15   .. -10
    (0.75, 0.0, 0.0),    # -10   .. -6
    (1.0, 0.08, 0.00),   # -6    .. -4
    (1.0, 0.38, 0.00),   # -4    .. -2
    (1.0, 0.63, 0.00),   # -2    .. -1
    (1.0, 0.75, 0.24),   # -1    .. -0.5
    (1.0, 0.91, 0.48),   # -0.5  .. -0.25
    (1.0, 1.00, 1.00),   # -0.25 .. 0.25 (white)
    (0.78, 1.0, 0.75),   # 0.25  .. 0.5
    (0.59, 0.96, 0.55),  # 0.5   .. 1
    (0.47, 0.96, 0.45),  # 1     .. 2
    (0.22, 0.82, 0.24),  # 2     .. 4
    (0.12, 0.71, 0.12),  # 4     .. 6
    (0.02, 0.59, 0.02),  # 6     .. 10
    (0.12, 0.43, 0.92),  # 10    .. 15
    # NOTE(review): 16 levels define only 15 bands, so this 16th color has
    # no band of its own -- confirm whether it was meant as an extra color.
    (0.0, 0.0, 1.0),
]

# -----------------
# Detect number of leads automatically
# -----------------
# Take the smaller leading-axis length of the "fcst" variable so every lead
# that is plotted exists in both the climatology and the current file.
nleads = min(nclim.variables["fcst"].shape[0],
             ncurr.variables["fcst"].shape[0])

print(f"Detected {nleads} forecast leads.")

#-------------------------------------------------------------------------
# Read the seasons file and split every line into whitespace-separated words.
# Only the first line is used: word 0 (`monic`) is printed in the plot title
# just before the year, and words 1..nleads label each lead's forecast period.
with open("./seed/seasons_file", "r", encoding="utf-8") as f:
    lines = f.readlines()

# Store words in an array (list of lists)
words_array = [line.strip().split() for line in lines]

# Fail before any plotting if the first line cannot label every lead;
# otherwise the plot loop would stop with an IndexError partway through,
# after some images had already been written.
first_words = words_array[0] if words_array else []
if len(first_words) < nleads + 1:
    raise ValueError(
        f"./seed/seasons_file: first line has {len(first_words)} word(s); "
        f"expected at least {nleads + 1} (a label plus one per lead)"
    )
monic = first_words[0]
#-------------------------------------------------------------------------

# -----------------
# Loop over forecast leads
# -----------------
for lead in range(nleads):
    fcm = words_array[0][lead+1]

    # Compute anomaly (scale by 86400 → mm/day)
insert.txt

    # Subset to Africa
    anm_sub = anm[np.ix_(lat_mask, lon_mask)]

    # Create plot
    fig, ax = plt.subplots(
        figsize=(10, 6),
        subplot_kw={"projection": ccrs.PlateCarree()}
    )

    # Africa domain
    ax.set_extent([70, 149, 15, 60], crs=ccrs.PlateCarree())

    # Map features
    ax.coastlines(resolution="110m")
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.add_feature(cfeature.LAND, facecolor="lightgray")
    ax.add_feature(cfeature.OCEAN, facecolor="lightblue")

    # Gridlines & ticks
    ax.set_xticks(np.arange(70, 151, 10), crs=ccrs.PlateCarree())
    ax.set_yticks(np.arange(15, 61, 5), crs=ccrs.PlateCarree())
    ax.xaxis.set_major_formatter(cticker.LongitudeFormatter())
    ax.yaxis.set_major_formatter(cticker.LatitudeFormatter())
    ax.tick_params(axis='both', which='major', labelsize=8.0, direction='out')
    # Adjust font sizes
#-------------------------------------
    gl = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,    # show tick labels
        linewidth=0.5,
        color='gray',
        alpha=0.5,
        linestyle='--'
    )

# Control which labels show up
    gl.top_labels = False
    gl.right_labels = False
    gl.left_labels = False
    gl.bottom_labels = False
#-------------------------------------
    

    # Contour plot
    contour = ax.contourf(
        Lon_sub, Lat_sub, anm_sub,
        levels=levels,
        colors=colors,
        transform=ccrs.PlateCarree()
    )

    # Title
    plt.title(f"CFSv2 Precip Rate Anom. (mm per day)  {monic}2025   {lead} Month Lead   {fcm} Fcst", fontsize=8)

    # Colorbar
    cbar = fig.colorbar(contour, extend="both", shrink=0.8, aspect=40,
                        orientation="horizontal", pad=0.06, 
                        ax=ax, ticks=levels)
    #
    cbar.ax.tick_params(labelsize=8)   # tick labels
    cbar.ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda val, pos: f'{int(val)}'))
#
#
# --- Remove end ticks entirely ---
    ticks = cbar.get_ticks()
    cbar.set_ticks(ticks[1:-1])
# --- Custom tick formatter: integer if x < -0.5 or  x > 0.5 , else float
    def custom_format(x, pos):
        return f'{int(x)}' if x < -0.5 or  x > 0.5 else f'{x:.2f}'
    cbar.formatter = plt.FuncFormatter(custom_format)
    cbar.update_ticks()  # apply form
#
#    cbar.set_label("mm/day", fontsize=9)  # label fon

    # Save file
    file_name = f"contour_lead{lead}_easia.png"
    plt.savefig(file_name, bbox_inches="tight", dpi=150)
    plt.close(fig)

    # Trim borders with ImageMagick
    subprocess.run(["convert", file_name, "-trim", file_name], check=True)

