import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# Plot every forecast target of a CFSv2 ensemble-mean precipitation file as a
# global map, writing one PNG per value of the `target` dimension.

NC_PATH = "CFSv2.prate.200203.ENSMEAN.fcst.nc"

# Name of the data variable to plot. None means "use the first data variable
# that has a `target` dimension".
# NOTE(review): the original script used an undefined `var`; set this
# explicitly to the precipitation-rate variable's name once confirmed
# against the file.
VAR_NAME = None

# 86400 s/day turns a per-second rate into a per-day amount.
# Assumes the variable is a rate in kg m-2 s-1 (== mm s-1), which gives
# mm/day. TODO confirm via var.attrs["units"].
SECONDS_PER_DAY = 86400


def _select_variable(ds, name=None):
    """Return ``ds[name]``, or the first data variable with a ``target`` dim.

    Raises:
        KeyError: if *name* is given but not present in *ds*.
        ValueError: if *name* is None and no data variable has a
            ``target`` dimension.
    """
    if name is not None:
        return ds[name]
    for da in ds.data_vars.values():
        if "target" in da.dims:
            return da
    raise ValueError("No data variable with a 'target' dimension was found.")


def _plot_target(lon, lat, field, target_val, out_path):
    """Draw one 2-D field on a global Robinson map and save it to *out_path*.

    Args:
        lon, lat: coordinate arrays for *field*, in degrees.
        field: 2-D values to plot (presumably (lat, lon); TODO confirm).
        target_val: raw ``target`` coordinate value, shown in the title.
        out_path: destination PNG path.
    """
    fig = plt.figure(figsize=(8, 6))
    try:
        # Robinson projection for a global view.
        ax = fig.add_subplot(1, 1, 1, projection=ccrs.Robinson())

        # lon/lat are plain geographic degrees, so the data CRS is PlateCarree
        # even though the map itself is drawn in Robinson.
        im = ax.pcolormesh(lon, lat, field,
                           cmap="viridis", shading="auto",
                           vmin=0, vmax=50,
                           transform=ccrs.PlateCarree())

        # Coastlines, borders, and labelled gridlines. The gridline labels mark
        # longitude/latitude, so no separate axis labels are set.
        ax.coastlines()
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        gl = ax.gridlines(draw_labels=True, linewidth=0.5,
                          color="gray", alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False

        # Exactly one colorbar per figure, attached to this map's axes.
        cbar = fig.colorbar(im, ax=ax, orientation="vertical",
                            pad=0.02, shrink=0.8)
        cbar.set_label("Precipitation (mm/day)")

        ax.set_title(f"Forecast at target={target_val}")

        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting fails, to free memory.
        plt.close(fig)


def main(path=NC_PATH, var_name=VAR_NAME):
    """Plot every ``target`` slice of the selected variable in *path*.

    Writes ``forecast_target_<int(target)>.png`` for each target value.
    """
    # decode_times=False keeps `target` as raw numbers; the filename uses
    # int(target_val) and the title shows the raw value.
    with xr.open_dataset(path, decode_times=False) as ds:
        var = _select_variable(ds, var_name)
        for i, target_val in enumerate(ds.target.values):
            # NOTE(review): assumes the slice is 2-D (lat, lon); add
            # .squeeze() or further .isel() if the variable has more dims.
            field = SECONDS_PER_DAY * var.isel(target=i)
            _plot_target(ds.lon, ds.lat, field, target_val,
                         f"forecast_target_{int(target_val)}.png")


if __name__ == "__main__":
    main()
