{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f30aa9dc",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "# ocean: fx\n",
    "\n",
    "fx ocean grid data\n",
    "\n",
    "**coordinate**: ti-u-hxy-sea\n",
    "- ti: time independent\n",
    "- u: ocean surface\n",
    "- hxy: horizontal curvilinear grids\n",
    "- sea: ocean domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f107ce",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "## Import libraries\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import cartopy.feature as cfeature\n",
    "import cartopy.crs as ccrs\n",
    "import cmocean\n",
    "import sys\n",
    "import os\n",
    "import glob\n",
    "\n",
    "sys.path.append(os.getcwd())\n",
    "from utils import load_grid_vertex, read_variables, read_compound_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1227985d",
   "metadata": {
    "papermill": {},
    "tags": [
     "parameters",
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# default parameters for the cmorized data; papermill overrides them at run time (injected-parameters cell)\n",
    "cmorout       = ''  # root for cmorized data, e.g., '/scratch/$USER/cmorout'\n",
    "source_id     = ''  # model name, e.g., 'NorESM3-LM'\n",
    "experiment_id = ''  # experiment name, e.g., 'historical', 'ssp585', 'piControl'\n",
    "variant_label = ''  # variant label, e.g., 'r1i1p1f1'\n",
    "grid_label    = ''  # grid label, e.g., 'gn', 'gr', 'g999'\n",
    "version       = ''  # version, e.g., 'v20260601'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a29a44a",
   "metadata": {
    "tags": [
     "injected-parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters injected by papermill for this run (they override the defaults above)\n",
    "cmorout = '/scratch/yanchun/cmorout'\n",
    "source_id = 'DUMMY-MODEL'\n",
    "experiment_id = 'piControl'\n",
    "variant_label = 'r1i1p1f1'\n",
    "grid_label = 'g999'\n",
    "version = 'v20260601'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8314513c",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# data path\n",
    "data_path = os.path.join(cmorout, source_id, experiment_id, version)\n",
    "\n",
    "# load grid coordinates for plotting, plus pmask (used later to mask points) and parea\n",
    "grid_file = 'data/grid.nc'\n",
    "lat, lon, clat, clon = load_grid_vertex(grid_file)\n",
    "# .load() reads the values into memory before the with-block closes the file,\n",
    "# so the arrays stay usable afterwards; dims are renamed y/x -> j/i, presumably\n",
    "# to match the dimension names of the CMORized data (TODO confirm)\n",
    "with xr.open_dataset(grid_file) as ds:\n",
    "    pmask = ds['pmask'].load().rename({'y': 'j', 'x': 'i'})\n",
    "    parea = ds['parea'].load().rename({'y': 'j', 'x': 'i'})\n",
    "\n",
    "# load per-compound-name plotting methods (e.g. 'vertical', 'timeseries', 'cmap' entries)\n",
    "methods = read_variables('data/methods.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06c51b7",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**List of data validated:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a5a359",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# load compound names and list the ones this notebook validates\n",
    "\n",
    "coords = ('ti-u-hxy-sea', 'ti-u-hxy-u', 'ti-ol-hxy-sea')\n",
    "cnames = read_compound_name('data/variables.nml')\n",
    "# examples of compound names:\n",
    "# cnames = ['ocean.tos.tavg-u-hxy-sea.mon.glb', 'ocean.ficeberg.tavg-u-hxy-sea.mon.glb']\n",
    "# dot-separated fields: [0] realm, [1] variable, [2] coordinate, ...\n",
    "for cname in cnames:\n",
    "    fields = cname.split('.')\n",
    "    realm, coord = fields[0], fields[2]\n",
    "    if realm == 'ocean' and coord in coords:\n",
    "        print(cname)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25072041",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# loop through compound names, load each ocean fx field and plot it as a map\n",
    "for cname in cnames:\n",
    "    # dot-separated fields: [0] realm, [1] variable, [2] coordinate, ...\n",
    "    fields = cname.split('.')\n",
    "    realm, var, coord = fields[0], fields[1], fields[2]\n",
    "\n",
    "    # filter first, so compound names this notebook ignores do not trigger\n",
    "    # 'not found in methods.txt' messages\n",
    "    if realm != 'ocean' or coord not in coords:\n",
    "        continue\n",
    "\n",
    "    # plotting methods from methods.txt, falling back to defaults\n",
    "    if cname not in methods:\n",
    "        print(f\"{cname} not found in methods.txt, using default methods for plotting.\")\n",
    "    mth = methods.get(cname) or {}\n",
    "    mth_vert = mth.get('vertical', 'mean')\n",
    "    mth_cmap = mth.get('cmap', 'mpl.colormaps[\"viridis\"]')\n",
    "\n",
    "    # glob returns paths already prefixed with data_path, so use them as-is;\n",
    "    # sort because glob order is arbitrary and only the first match is read\n",
    "    pattern = var+'_'+coord+'_*_'+grid_label+'_'+source_id+'_'+experiment_id+'_'+variant_label+'.nc'\n",
    "    matches = sorted(glob.glob(os.path.join(data_path, pattern)))\n",
    "    if not matches:\n",
    "        continue\n",
    "\n",
    "    # load into memory before the with-block closes the file\n",
    "    with xr.open_dataset(matches[0]) as ds:\n",
    "        if var not in ds:\n",
    "            continue\n",
    "        data = ds[var].load()\n",
    "\n",
    "    print(f'\\033[1m{cname}\\033[0m')\n",
    "    print(f'long name: {data.long_name} ({data.units})')\n",
    "\n",
    "    # fields on ocean levels are collapsed over 'lev' and masked where pmask != 1\n",
    "    if coord == 'ti-ol-hxy-sea':\n",
    "        if mth_vert == 'mean':\n",
    "            data2d = data.mean(dim='lev', keep_attrs=True).where(pmask == 1)\n",
    "        elif mth_vert == 'sum':\n",
    "            data2d = data.sum(dim='lev', keep_attrs=True).where(pmask == 1)\n",
    "        else:\n",
    "            raise ValueError(f'Unsupported vertical aggregation method: {mth_vert}')\n",
    "        print(f'vertical aggregation method for 3D->2D plot: {mth_vert}')\n",
    "    else:\n",
    "        data2d = data\n",
    "\n",
    "    proj = ccrs.PlateCarree(central_longitude=90.0)\n",
    "    fig, ax = plt.subplots(1, figsize=(11.7, 4), dpi=96, subplot_kw={\"projection\": proj})\n",
    "\n",
    "    # NOTE(review): mth_cmap is Python text from methods.txt passed to eval();\n",
    "    # only run this notebook with a trusted methods file\n",
    "    pm = ax.pcolormesh(lon, lat, data2d, cmap=eval(mth_cmap), norm=None,\n",
    "                       transform=ccrs.PlateCarree(), shading='auto', rasterized=True)\n",
    "\n",
    "    # map decorations\n",
    "    ax.coastlines(resolution='110m')\n",
    "    ax.add_feature(cfeature.LAND, facecolor='lightgray')\n",
    "    gl = ax.gridlines(ylocs=range(-90, 90, 30), draw_labels=True)\n",
    "    gl.ylocator = mpl.ticker.FixedLocator(range(-90, 90, 30))\n",
    "\n",
    "    # colorbar labelled with the variable's units\n",
    "    cb = fig.colorbar(pm, ax=ax, fraction=0.4, shrink=0.8)\n",
    "    cb.set_label(label=data.units, size=14)\n",
    "    cb.ax.tick_params(labelsize=12)\n",
    "\n",
    "    ax.set_title(data.long_name)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    # close explicitly so figures do not accumulate across iterations\n",
    "    plt.close(fig)\n",
    "    del data, data2d, ax, pm, cb, gl, fig"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cmip7validate-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.13.3"
  },
  "papermill": {
   "default_parameters": {},
   "environment_variables": {},
   "input_path": "notebooks/ocean_fx.ipynb",
   "output_path": "ocean_fx.ipynb",
   "parameters": {
    "cmorout": "/scratch/yanchun/cmorout",
    "experiment_id": "piControl",
    "grid_label": "g999",
    "source_id": "DUMMY-MODEL",
    "variant_label": "r1i1p1f1",
    "version": "v20260601"
   },
   "version": "2.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}