import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import os

def _read_sfr_file(file_path):
    """Read the 24-column fixed-width granule file into a DataFrame.

    The lat/lon/sfr columns used for filtering are coerced to numbers so that
    malformed fields become NaN instead of breaking the numeric comparisons.
    """
    column_names = [
        'lat', 'lon', 'latc', 'lonc', 'chgt', 'tprcp', 'rr', 'sfr',
        'flagt', 'flagr', 'flags', 'prob', 'clw', 'flagc', 'stype',
        'yyyy', 'doy', 'hh', 'n0', 'bm', 'zeni', 'azi', 'slat', 'slon'
    ]
    # The Fortran writer uses width-10 descriptors (f10.3, f10.2, i10) for
    # every variable, so all 24 columns are 10 characters wide.
    widths = [10] * len(column_names)
    df = pd.read_fwf(file_path, widths=widths, names=column_names)

    # A Fortran field that overflows its width is written as asterisks, which
    # would turn the whole column into strings and make `>= 0` raise a
    # TypeError. Coerce such text to NaN; NaN then fails every filter below.
    for col in ('lat', 'lon', 'sfr'):
        df[col] = pd.to_numeric(df[col], errors='coerce')
    return df


def _granule_labels(file_path):
    """Return ``(satellite, timestamp)`` taken from a dash-separated file name.

    e.g. ``SNPP-MIRV11-RAW1-V2.0-202605081430`` -> ``('SNPP', '202605081430')``.
    For a name without dashes both values are the whole base name.
    """
    parts = os.path.basename(file_path).split('-')
    return parts[0], parts[-1]


def plot_sfr_distribution(file_path, title=None, output_img=None):
    """Scatter-plot snowfall-rate (sfr) values from one granule file on a map.

    Only points inside 50-70 N / 180-240 E (0-360 longitude convention) with
    a non-negative ``sfr`` are plotted. The figure is saved as a PNG, then
    shown, then closed.

    Parameters
    ----------
    file_path : str
        Path to the fixed-width text file (24 columns, 10 characters each).
    title : str, optional
        Plot title. Defaults to ``"<satellite> SFR Distribution at <timestamp>"``
        with both parts parsed from the dash-separated file name.
    output_img : str, optional
        Output PNG path. Defaults to
        ``"sfr_distribution_<timestamp>_<satellite>.png"``.

    Returns
    -------
    None
        Also returns early, after printing a message, when the file does not
        exist or no point survives the filters.
    """
    if not os.path.isfile(file_path):
        print(f"Error: Could not find '{file_path}'.")
        return

    print(f"Loading fixed-width data from {file_path}...")
    df = _read_sfr_file(file_path)

    # Standardize longitude to 0-360 in case the file stores -180..180.
    df['lon_360'] = df['lon'] % 360

    lat_min, lat_max = 50.0, 70.0
    lon_min, lon_max = 180.0, 240.0

    # `between` is inclusive on both ends. NaN compares False, so unparseable
    # values drop out here together with negative missing-value sentinels
    # (e.g. -999.0) in sfr.
    keep = (
        df['lat'].between(lat_min, lat_max)
        & df['lon_360'].between(lon_min, lon_max)
        & (df['sfr'] >= 0)
    )
    df_filtered = df[keep]

    if df_filtered.empty:
        print("No valid SFR data points found in the specified region.")
        return

    print(f"Plotting {len(df_filtered)} points...")

    # Default labels come from the file name so each granule gets its own
    # title and image instead of a hard-coded one.
    satellite, timestamp = _granule_labels(file_path)
    if title is None:
        title = f'{satellite} SFR Distribution at {timestamp}'
    if output_img is None:
        output_img = f'sfr_distribution_{timestamp}_{satellite}.png'

    fig = plt.figure(figsize=(10, 7))
    # Centre the projection on 180 so the 180-240 E box sits mid-map rather
    # than against the +/-180 seam.
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=180))

    # The extent is passed in the standard (-180..180) PlateCarree frame.
    ax.set_extent([lon_min - 360, lon_max - 360, lat_min, lat_max], crs=ccrs.PlateCarree())

    sfr_max = df_filtered['sfr'].max()
    sc = ax.scatter(
        df_filtered['lon_360'], df_filtered['lat'],
        c=df_filtered['sfr'], cmap='jet',
        s=5,  # marker size; adjust if dots are too small or overlap too much
        transform=ccrs.PlateCarree(),
        vmin=0,
        vmax=sfr_max if sfr_max > 0 else 5,  # avoid a degenerate 0..0 colour range
    )

    # Geographic context.
    ax.coastlines(color='black', linewidth=1.5)
    ax.add_feature(cfeature.BORDERS, edgecolor='black', linewidth=1)
    ax.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)
    ax.add_feature(cfeature.OCEAN, facecolor='azure', alpha=0.3)

    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False
    gl.right_labels = False

    cbar = fig.colorbar(sc, ax=ax, orientation='horizontal', pad=0.08, shrink=0.8)
    cbar.set_label('Snowfall Rate (sfr)')

    ax.set_title(title)

    fig.savefig(output_img)
    print(f"Plot saved successfully as '{output_img}'.")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# --- Run the function ---
# Guarded so that importing this module does not read data or open a
# blocking plot window; the plot is produced only when run as a script.
if __name__ == "__main__":
    # Path of the granule file to plot.
    plot_sfr_distribution('../data/SNPP-MIRV11-RAW1-V2.0-202605081430')
    # plot_sfr_distribution('../data/NO20-MIRV11-RAW1-V2.0-202605081430')
