get_subpolar_gyre_functions.py 6.09 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193

import numpy as np
import netCDF4 as nc
import matplotlib.pyplot as plt
from jupyterthemes import jtplot
import matplotlib as mpl
import pandas as pd
import datetime
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import os
import glob
import xarray as xr


# Global variables
# Bounding box of the subpolar-gyre region.
# Longitudes are kept in two forms: on a 0..360 axis (lon_min / lon_max) and
# shifted by -360 (lon_min_W / lon_max_W). subpolargyre() switches to the
# shifted pair when a grid contains negative longitudes.
lat_min, lat_max = 46, 61
lon_min, lon_max = 360 - 55, 360 - 20
lon_min_W, lon_max_W = lon_min - 360, lon_max - 360

# Coordinate names that are accepted as latitude / longitude of a dataset
lat_keys = ["lat", "latitude", "nav_lat"]
lon_keys = ["lon", "longitude", "nav_lon"]


# Functions to specify/edit the global variables
# set parameters
def set_boundaries(
    longitude_maximum = 360-20,
    longitude_minimum = 360-55,
    latitude_maximum = 61,
    latitude_minimum = 46):
    """Overwrite the module-level bounding box of the subpolar-gyre region.

    Args:
        longitude_maximum: New value for ``lon_max``.
        longitude_minimum: New value for ``lon_min``.
        latitude_maximum: New value for ``lat_max``.
        latitude_minimum: New value for ``lat_min``.

    The shifted pair ``lon_max_W`` / ``lon_min_W`` is recomputed as the given
    longitudes minus 360. Nothing is returned; only the globals change.
    """
    global lon_max, lon_min, lat_max, lat_min, lon_max_W, lon_min_W

    lat_min, lat_max = latitude_minimum, latitude_maximum
    lon_min, lon_max = longitude_minimum, longitude_maximum
    # Same meridians, shifted by one full turn.
    lon_min_W, lon_max_W = longitude_minimum - 360, longitude_maximum - 360

def set_lon_lat_keys(
    longitudes=['lon', 'longitude', 'nav_lon'],
    latitudes =['lat', 'latitude', 'nav_lat']):
    """Replace the module-level lists of accepted coordinate names.

    Args:
        longitudes: Coordinate names to accept as longitude. A single name
            may be given as a plain string.
        latitudes: Coordinate names to accept as latitude. A single name
            may be given as a plain string.

    Nothing is returned; ``lon_keys`` and ``lat_keys`` are overwritten.
    """
    global lon_keys, lat_keys
    # Store copies, never the argument objects themselves: otherwise a later
    # in-place change of lon_keys / lat_keys would silently modify the
    # default lists of this function (or a list owned by the caller).
    # A bare string is wrapped so that it is matched as one whole name.
    lon_keys = [longitudes] if isinstance(longitudes, str) else list(longitudes)
    lat_keys = [latitudes] if isinstance(latitudes, str) else list(latitudes)

def get_values():
    """Print the current bounding box and the accepted coordinate names.

    One line per setting, in the form ``name: value``. The shifted
    longitudes ``lon_min_W`` / ``lon_max_W`` are not printed.
    """
    settings = (
        ("lon_min", lon_min),
        ("lon_max", lon_max),
        ("lat_min", lat_min),
        ("lat_max", lat_max),
        ("lon_keys", lon_keys),
        ("lat_keys", lat_keys),
    )
    for label, value in settings:
        print(label + ": " + str(value))

# function to get the longitude / latitude keys from the model
def get_lon_lat_keys(data, location_keys):
    """Return the name of the coordinate of ``data`` listed in ``location_keys``.

    Args:
        data: Object whose ``coords.keys()`` yields its coordinate names.
        location_keys: Accepted coordinate names (e.g. ``lon_keys``).

    Returns:
        The first coordinate name of ``data`` (in the order of
        ``data.coords``) that is contained in ``location_keys``.

    Raises:
        IndexError: If no coordinate of ``data`` is in ``location_keys``.
    """
    available = list(data.coords.keys())
    matches = [key for key in available if key in location_keys]
    if not matches:
        # Without this guard the function announced "the first one is used"
        # and then failed with a bare "list index out of range".
        # The exception type is kept, only the message is made useful.
        raise IndexError(
            'None of the coordinates ' + str(available)
            + ' is one of the expected keys ' + str(location_keys) + '.'
        )
    if len(matches) > 1:
        print('CAUTION: There exist the keys: ' + str(matches) + '. The first one is used!')
    return matches[0]

# let all datasets have the same representation of time
def time_to_year_month(data):
    """Return ``data`` with its ``time`` coordinate reduced to year-month.

    Every entry of ``data.time.values`` is converted to a
    ``numpy.datetime64`` with unit ``'M'``; the resulting list is handed to
    ``data.assign_coords`` as the new ``time`` coordinate.
    """
    monthly_stamps = []
    for stamp in data.time.values:
        monthly_stamps.append(np.datetime64(stamp, 'M'))
    return data.assign_coords({"time": monthly_stamps})

# function to extract subpolar gyre mean
def subpolargyre(data, str_lon, str_lat, print_info=False):
    """Return the box mean of ``data`` minus the mean over the whole grid.

    Args:
        data: Field to reduce; indexed with ``str_lon`` / ``str_lat`` and
            reduced with ``.where`` / ``.mean`` (an xarray object in
            get_subpolar_gyre()).
        str_lon: Name of the longitude coordinate of ``data``.
        str_lat: Name of the latitude coordinate of ``data``.
        print_info: If true, print the dimensions that are averaged over.

    Returns:
        Mean over the box given by the module-level boundaries minus the
        mean over all grid points; every dimension except ``"time"`` is
        averaged away.

    NOTE(review): both means are plain ``.mean`` calls, i.e. no grid-cell
    area weighting is applied here.
    """
    mean_dims = [dim for dim in data.dims if dim != "time"]
    if print_info:
        print("   The mean is taken over the dimensions ", mean_dims)

    # Grids that contain negative longitudes get the boundaries shifted by -360.
    if (data[str_lon].values < 0).any():
        west, east = lon_min_W, lon_max_W
    else:
        west, east = lon_min, lon_max

    inside_box = (
        data[str_lon] >= west,
        data[str_lon] <= east,
        data[str_lat] >= lat_min,
        data[str_lat] <= lat_max,
    )
    region = data
    for condition in inside_box:
        region = region.where(condition)

    regional_mean = region.mean(dim = mean_dims)
    overall_mean = data.mean(dim = mean_dims)
    # Yearly means could be formed from the result via groupby('time.year').mean().
    return regional_mean - overall_mean

# wrapper to get the spg mean for an ensemble member
def get_subpolar_gyre(model, emember, print_info=False, paths_all_models=None):
    """Compute the subpolar-gyre index of one ensemble member.

    Args:
        model: Model name, first-level key of ``paths_all_models``.
        emember: Ensemble-member name, second-level key of ``paths_all_models``.
        print_info: If true, print the coordinate names that were picked and
            pass the flag on to subpolargyre().
        paths_all_models: Nested dict ``{model: {member: [files, ...]}}`` as
            built by get_paths_dict(); built on the fly when ``None``.

    Returns:
        The result of subpolargyre() for the ``tos`` variable with a
        year-month time coordinate, or ``None`` when the member is skipped
        (not exactly one path entry, or the hard-coded incomplete member of
        'EC-Earth3-Veg').

    Raises:
        KeyError: If ``model`` / ``emember`` is not in ``paths_all_models``.
    """
    if paths_all_models is None:
        paths_all_models = get_paths_dict()
    paths = paths_all_models[model][emember]

    if len(paths) != 1:
        # there is one model where there are 2 paths for the same ensemble member...
        print("CAUTION: The model '" + model + "' provides " + str(len(paths)) + " paths for the ensemble member '" + emember + "'. It is NOT added to CMIP6_amoc_index.nc, this must be done manually!")
        return None
    if model == 'EC-Earth3-Veg' and emember == 'r10i1p1f1':
        print("   Ensemble-member '"+emember+ "' does not provide all years. It also does not occur in Matthew's data, so its not included here.")
        return None

    # `concat_dim` is only accepted together with combine="nested"; the
    # default combine="by_coords" rejects it in current xarray. Nested
    # combining concatenates in list order, so the files are sorted first
    # (glob returns them in arbitrary order). A single glob string is
    # passed through unchanged.
    files = paths[0] if isinstance(paths[0], str) else sorted(paths[0])
    data = xr.open_mfdataset(files, combine="nested", concat_dim="time").tos

    str_lon = get_lon_lat_keys(data, lon_keys)
    str_lat = get_lon_lat_keys(data, lat_keys)
    if print_info:
        print('   + Longitude: ' + str_lon)
        print('   + Latitude:  ' + str_lat)

    spg = subpolargyre(data, str_lon, str_lat, print_info)
    spg = time_to_year_month(spg)

    if print_info:
        print('  Ensemble-members added:')
    print('    - ' +emember)

    return spg

# make plots of world
def make_nice_plot(ds, str_lon, str_lat):
    """Plot ``ds`` on an orthographic map centred on longitude 320.

    Args:
        ds: Object with a ``.plot`` method accepting ``x``, ``y``, ``ax``
            and ``transform``.
        str_lon: Name of the coordinate used for the x axis.
        str_lat: Name of the coordinate used for the y axis.

    Returns:
        The newly created matplotlib figure.
    """
    globe = ccrs.Orthographic(central_longitude=320)
    fig, axis = plt.subplots(1, 1, subplot_kw={"projection": globe})
    # The coordinates are passed as PlateCarree and drawn on the globe axes.
    plot_options = {
        "x": str_lon,
        "y": str_lat,
        "ax": axis,
        "transform": ccrs.PlateCarree(),
    }
    ds.plot(**plot_options)
    axis.coastlines(linewidth=2, alpha=0.6)
    return fig

# make nice plots of spg area
def make_nice_plot_of_spg(data, str_lon, str_lat, model=None):
    """Plot only the part of ``data`` that lies inside the subpolar-gyre box.

    Args:
        data: Field to plot; masked with ``.where`` against the module-level
            boundaries before plotting.
        str_lon: Name of the longitude coordinate of ``data``.
        str_lat: Name of the latitude coordinate of ``data``.
        model: Optional figure title.

    Returns:
        The newly created matplotlib figure.
    """
    # Grids that contain negative longitudes get the boundaries shifted by -360.
    if (data[str_lon].values < 0).any():
        west, east = lon_min_W, lon_max_W
    else:
        west, east = lon_min, lon_max

    inside_box = (
        data[str_lon] >= west,
        data[str_lon] <= east,
        data[str_lat] >= lat_min,
        data[str_lat] <= lat_max,
    )
    region = data
    for condition in inside_box:
        region = region.where(condition)

    globe = ccrs.Orthographic(central_longitude=320)
    fig, axis = plt.subplots(1, 1, subplot_kw={"projection": globe})
    region.plot(x=str_lon, y=str_lat, ax=axis, transform=ccrs.PlateCarree())
    axis.coastlines(linewidth=2, alpha=0.6)
    if model is not None:
        fig.suptitle(model)
    return fig

def get_paths_dict(path_base='/p/tmp/mayayami/SYNDA/data/CMIP6/CMIP/'):
    """Collect the ``tos_Omon_*.nc`` files of all historical runs below a root.

    The directory layout that is searched is
    ``<path_base>/<institute>/<model>/historical/<member>/...``.

    Args:
        path_base: Root directory containing one sub-directory per institute.

    Returns:
        Nested dict ``{model: {member: data_paths}}``. ``data_paths`` holds
        one sorted list of file paths for every directory below the member
        directory that contains matching files; it is empty when no file was
        found. Models without a ``historical`` directory are left out.
    """
    paths_all_models = dict()

    # Directory listings and glob results come back in arbitrary order, so
    # everything is sorted to make the result reproducible.
    for institute in sorted(os.listdir(path_base)):
        institute_dir = os.path.join(path_base, institute)
        if not os.path.isdir(institute_dir):
            # stray file next to the institute directories
            continue
        for model in sorted(os.listdir(institute_dir)):
            historical_dir = os.path.join(institute_dir, model, 'historical')
            if not os.path.isdir(historical_dir):
                # model without a historical experiment
                continue
            members = dict()
            for member in sorted(os.listdir(historical_dir)):
                member_dir = os.path.join(historical_dir, member)
                data_paths = []
                for directory in sorted(dirpath for dirpath, _, _ in os.walk(member_dir)):
                    files = sorted(glob.glob(os.path.join(directory, 'tos_Omon_*.nc')))
                    if files:
                        data_paths.append(files)
                if not data_paths:
                    print("No 'tos_Omon_*.nc' files found for model '"+model+"' and ensemble member '"+member+"'!")
                elif len(data_paths) > 1:
                    print("Several paths provided for model '"+model+"' and ensemble member '"+member+"'!")
                    print(data_paths)
                members[member] = data_paths
            paths_all_models[model] = members
    return paths_all_models