{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f30aa9dc",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "# ocean: 2D\n",
    "\n",
    "Validation of 2D ocean variables: for each variable found, a map of its time mean and a global timeseries are plotted.\n",
    "\n",
    "**coordinates**: 'tavg-u-hxy-sea', 'tavg-u-hxy-ifs', 'tmax-u-hxy-sea', 'tmin-u-hxy-sea', 'tavg-d300m-hxy-sea', 'tavg-d700m-hxy-sea', 'tavg-d2000m-hxy-sea'\n",
    "- tavg: time average\n",
    "- tmax: time maximum\n",
    "- tmin: time minimum\n",
    "- u: ocean surface\n",
    "- d300m: layer at 300m depth\n",
    "- d700m: layer at 700m depth\n",
    "- d2000m: layer at 2000m depth\n",
    "- hxy: horizontal curvilinear grid\n",
    "- sea: ocean domain\n",
    "- ifs: ice-free sea (ocean) domain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f107ce",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "## Import libraries\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import cartopy.feature as cfeature\n",
    "import cartopy.crs as ccrs\n",
    "import cmocean\n",
    "import sys\n",
    "import os\n",
    "import glob\n",
    "\n",
    "sys.path.append(os.getcwd())\n",
    "from utils import load_grid_vertex, read_variables, read_compound_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1227985d",
   "metadata": {
    "papermill": {},
    "tags": [
     "parameters",
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# default values for the cmorized-data parameters;\n",
    "# papermill overrides them through an injected cell when the notebook is executed\n",
    "cmorout       = ''   # root for cmorized data, e.g., '/scratch/$USER/cmorout'\n",
    "source_id     = ''   # model name, e.g., 'NorESM3-LM'\n",
    "experiment_id = ''   # experiment name, e.g., 'historical', 'ssp585', 'piControl'\n",
    "variant_label = ''   # variant label, e.g., 'r1i1p1f1'\n",
    "grid_label    = ''   # grid label, e.g., 'gn', 'gr', 'g999'\n",
    "version       = ''   # version, e.g., 'v20260601'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96ceb110",
   "metadata": {
    "tags": [
     "injected-parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Parameters (injected by papermill; these override the defaults)\n",
    "cmorout       = \"/scratch/yanchun/cmorout\"\n",
    "source_id     = \"DUMMY-MODEL\"\n",
    "experiment_id = \"piControl\"\n",
    "variant_label = \"r1i1p1f1\"\n",
    "grid_label    = \"g999\"\n",
    "version       = \"v20260601\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8314513c",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# data path\n",
    "data_path = os.path.join(cmorout, source_id, experiment_id, version)\n",
    "\n",
    "# load grid coordinates (used for plotting) and cell area (used as weights)\n",
    "grid_file = 'data/grid.nc'\n",
    "lat, lon, clat, clon = load_grid_vertex(grid_file)\n",
    "with xr.open_dataset(grid_file) as ds:\n",
    "    # .load() reads the values into memory while the file is still open,\n",
    "    # so later arithmetic does not rely on the closed file handle\n",
    "    pmask = ds['pmask'].load()\n",
    "    parea = ds['parea'].load()\n",
    "\n",
    "# rename grid dims to the (j, i) dims used by the cmorized data\n",
    "parea = parea.rename({'y': 'j', 'x': 'i'})\n",
    "pmask = pmask.rename({'y': 'j', 'x': 'i'})\n",
    "\n",
    "# load per-variable plotting methods (e.g. timeseries aggregation, colormap)\n",
    "methods = read_variables('data/methods.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f06c51b7",
   "metadata": {
    "papermill": {},
    "tags": []
   },
   "source": [
    "**List of data validated:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67a5a359",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# list the compound names this notebook validates: the ocean realm combined with\n",
    "# one of the 2D coordinate labels below\n",
    "coords = ('tavg-u-hxy-sea', 'tavg-u-hxy-ifs', 'tmax-u-hxy-sea', 'tmin-u-hxy-sea', 'tavg-d300m-hxy-sea', 'tavg-d700m-hxy-sea', 'tavg-d2000m-hxy-sea')\n",
    "cnames = read_compound_name('data/variables.nml')\n",
    "\n",
    "# compound names are dot-separated, e.g. 'ocean.tos.tavg-u-hxy-sea.mon.glb':\n",
    "# field 0 is the realm, field 2 the coordinate label\n",
    "for cname in cnames:\n",
    "    fields = cname.split('.')\n",
    "    realm, coord = fields[0], fields[2]\n",
    "    if realm == 'ocean' and coord in coords:\n",
    "        print(cname)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25072041",
   "metadata": {
    "papermill": {},
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# loop through compound names: for each 2D ocean variable, plot the time-mean map\n",
    "# and a global timeseries\n",
    "for cname in cnames:\n",
    "    # compound name fields: realm.variable.coordinate.frequency...\n",
    "    fields = cname.split('.')\n",
    "    realm, var, coord = fields[0], fields[1], fields[2]\n",
    "\n",
    "    # filter first, so methods.txt messages are only printed for 2D ocean variables\n",
    "    if realm != 'ocean' or coord not in coords:\n",
    "        continue\n",
    "\n",
    "    # plotting defaults, optionally overridden per variable in data/methods.txt\n",
    "    mth_ts = 'mean'\n",
    "    mth_cmap = 'mpl.colormaps[\"viridis\"]'\n",
    "    if cname not in methods.keys():\n",
    "        print(f\"{cname} not found in methods.txt, using default methods for plotting.\")\n",
    "    elif methods[cname] is not None:\n",
    "        if 'timeseries' in methods[cname].keys():\n",
    "            mth_ts = methods[cname]['timeseries']\n",
    "        if 'cmap' in methods[cname].keys():\n",
    "            mth_cmap = methods[cname]['cmap']\n",
    "\n",
    "    # locate the cmorized file(s); glob returns paths that already include data_path\n",
    "    pattern = os.path.join(data_path, f'{var}_{coord}_*_{grid_label}_{source_id}_{experiment_id}_{variant_label}_*.nc')\n",
    "    files = sorted(glob.glob(pattern))\n",
    "    if not files:\n",
    "        continue\n",
    "\n",
    "    # only the first matching file (in sorted order) is read\n",
    "    with xr.open_dataset(files[0]) as ds:\n",
    "        if var not in ds:\n",
    "            continue\n",
    "        # load into memory while the file is still open\n",
    "        data2d = ds[var].load()\n",
    "\n",
    "    print(f'\\033[1m{cname}\\033[0m')\n",
    "    print(f'long name: {data2d.long_name} ({data2d.units})')\n",
    "    print(f'timeseries aggregation method for 3D->1D plot: {mth_ts}')\n",
    "\n",
    "    # global timeseries over the horizontal (j, i) dims; computed before the figure\n",
    "    # is created so an unsupported method does not leave an open figure behind\n",
    "    if mth_ts == 'mean':\n",
    "        # area-weighted mean, sum(x*a)/sum(a) over valid cells;\n",
    "        # weights must not contain NaN, hence fillna(0)\n",
    "        data1d = data2d.weighted(parea.fillna(0)).mean(dim=['j', 'i'], keep_attrs=True)\n",
    "    elif mth_ts == 'sum':\n",
    "        # area integral; NOTE(review): its units are data units times parea units,\n",
    "        # but the attrs copied below still carry the data units\n",
    "        data1d = (data2d*parea).sum(dim=['j', 'i'], keep_attrs=True)\n",
    "    else:\n",
    "        raise ValueError(f'Unsupported timeseries method: {mth_ts}')\n",
    "    data1d.attrs = data2d.attrs.copy()\n",
    "\n",
    "    fig = plt.figure(figsize=(16, 3), dpi=96)\n",
    "    gs = fig.add_gridspec(nrows=1, ncols=2, width_ratios=[2, 1])\n",
    "    proj = ccrs.PlateCarree(central_longitude=90.0)\n",
    "\n",
    "    # left panel: time-mean map\n",
    "    ax1 = fig.add_subplot(gs[0], projection=proj)\n",
    "    # NOTE(review): mth_cmap is eval'd as Python code taken from data/methods.txt;\n",
    "    # only use trusted method files\n",
    "    pm = ax1.pcolormesh(lon, lat, data2d.mean(dim='time'), cmap=eval(mth_cmap), norm=None,\n",
    "                        transform=ccrs.PlateCarree(), shading='auto', rasterized=True)\n",
    "    ax1.add_feature(cfeature.LAND, facecolor='lightgray')\n",
    "    ax1.coastlines(resolution='110m')\n",
    "    ax1.gridlines(ylocs=range(-90, 90, 30), draw_labels=True)\n",
    "    cb = fig.colorbar(pm, ax=ax1, fraction=0.4, shrink=0.8)\n",
    "    cb.set_label(label=data2d.units, size=14)\n",
    "    cb.ax.tick_params(labelsize=12)\n",
    "    ax1.set_title(data2d.long_name)\n",
    "\n",
    "    # right panel: global timeseries (or a text note when there are < 2 time steps)\n",
    "    ax2 = fig.add_subplot(gs[1])\n",
    "    ntime = data1d.sizes['time']\n",
    "    if ntime < 2:\n",
    "        value = data1d.values[0] if ntime == 1 else float('nan')\n",
    "        ax2.text(0.1, 0.8, f\"global {mth_ts} value: {value} {data1d.units}\")\n",
    "        ax2.text(0.1, 0.6, \"less than 2 time steps.\")\n",
    "        ax2.text(0.1, 0.4, \"skipping timeseries plot.\")\n",
    "    else:\n",
    "        data1d.plot(ax=ax2)\n",
    "    ax2.set_title(data1d.long_name)\n",
    "\n",
    "    # lay out once both panels exist, show, then close the figure so memory\n",
    "    # does not accumulate across the loop\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    plt.close(fig)\n",
    "\n",
    "    del data2d, data1d"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cmip7validate-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.13.3"
  },
  "papermill": {
   "default_parameters": {},
   "environment_variables": {},
   "input_path": "notebooks/ocean_2D.ipynb",
   "output_path": "ocean_2D.ipynb",
   "parameters": {
    "cmorout": "/scratch/yanchun/cmorout",
    "experiment_id": "piControl",
    "grid_label": "g999",
    "source_id": "DUMMY-MODEL",
    "variant_label": "r1i1p1f1",
    "version": "v20260601"
   },
   "version": "2.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}