smashbox.plot.plot
# -*- coding: utf-8 -*-
import colorsys
import datetime
import os

import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pandas import DataFrame
from smash import Model

from smashbox.init.param import param
from smashbox.stats import stats
from smashbox.tools import geo_toolbox
from smashbox.tools import tools


class plot_properties:
    """Class which handles the properties passed to the matplotlib ``plot`` function.
    All attributes can be defined by the user and the object plot_properties can be
    passed to any smashbox plot function.
    """

    def __init__(
        self,
        ls="-",
        lw=1.5,
        marker="",
        markersize=4,
        color="black",
        label="",
    ):
        self.ls = ls
        """The style of the line (see matplotlib documentation)"""
        self.lw = lw
        """The linewidth of the line (see matplotlib documentation)"""
        self.marker = marker
        """The style of the marker (see matplotlib documentation)"""
        self.markersize = markersize
        """The size of the marker (see matplotlib documentation)"""
        self.color = color
        """The color of the line (see matplotlib documentation)"""
        self.label = label
        """The label of the line (see matplotlib documentation)"""

    def update(self, **kwargs):
        """Set (or add) the attributes given as keyword arguments."""
        for key, value in kwargs.items():
            setattr(self, key, value)


class ax_properties:
    """Class which handles the properties of a matplotlib ``Axes`` object
    (see matplotlib documentation). All attributes can be defined by the user and the
    object ax_properties can be passed to any smashbox plot function.
    """

    def __init__(
        self,
        title: str = None,
        xlabel: str = None,
        ylabel: str = None,
        clabel: str = None,
        font_ratio: float = 1,
        title_fontsize: int = 12,
        label_fontsize: int = 10,
        annotate_fontsize: int = 6,
        grid: bool = True,
        xscale: str | None = None,
        yscale: str | None = None,
        legend: bool = True,
        legend_loc: str = None,
        legend_fontsize: int = 8,
        xtics_fontsize: int = 8,
        ytics_fontsize: int = 8,
        cmap: str | None = None,
        xlim: tuple | list | None = (None, None),
        ylim: tuple | list | None = (None, None),
        xticklabels_rotation: int = 0,
        barlabel_fontsize=6,
    ):

        self.title = title
        """The title of the graphic"""
        self.xlabel = xlabel
        """The label of the x axis"""
        self.ylabel = ylabel
        """The label of the y axis"""
        self.clabel = clabel
        """The label of the colorbar"""
        self.font_ratio = font_ratio
        """Ratio applied to every fontsize of this object"""
        self.title_fontsize = title_fontsize
        """The fontsize of the title"""
        self.label_fontsize = label_fontsize
        """The label fontsize"""
        self.annotate_fontsize = annotate_fontsize
        """The annotation fontsize in plot"""
        self.grid = grid
        """Set the grid (boolean), default True"""
        self.xscale = xscale
        """Scale of the x axis (see matplotlib documentation)"""
        self.yscale = yscale
        """Scale of the y axis (see matplotlib documentation)"""
        self.legend = legend
        """Set the legend, boolean, default is True"""
        self.legend_loc = legend_loc
        """Localisation of the legend"""
        self.legend_fontsize = legend_fontsize
        """The fontsize of the legend"""
        self.xtics_fontsize = xtics_fontsize
        """The fontsize of the xtics"""
        self.ytics_fontsize = ytics_fontsize
        """The fontsize of the ytics"""
        self.cmap = cmap
        """The name of the used colormap"""
        self.xlim = xlim
        """The limit of the x axis, tuple or list, default is (None,None)"""
        self.ylim = ylim
        """The limit of the y axis, tuple or list, default is (None,None)"""
        self.xticklabels_rotation = xticklabels_rotation
        """Angle of the xtics labels, float, default is 0."""
        # Stored unscaled, like every other fontsize: `font_ratio` is applied
        # where the label is drawn. Pre-multiplying here compounded the ratio
        # each time the object was copied with ax_properties(**obj.__dict__).
        self.barlabel_fontsize = barlabel_fontsize
        """Fontsize of the bar top label (multiplied by `font_ratio` when drawn)"""

    def update(self, **kwargs):
        """Set (or add) the attributes given as keyword arguments."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def change(self, figure):
        """Apply the properties to the current figure `figure`.

        The global matplotlib rcParams are first reset to their defaults, then
        the font sizes of this object (scaled by `font_ratio`) are installed.

        Parameter:
        ----------
        figure : list or tuple
            list of (fig, ax), figure and ax of a matplotlib subplot.
        Return:
        -------
        a tuple of the modified (fig,ax)
        """
        fig, ax = figure
        ratio = self.font_ratio

        # Reset first so that repeated calls do not accumulate global changes.
        plt.rcParams.update(plt.rcParamsDefault)

        plt.rcParams.update(
            {
                "axes.labelsize": self.label_fontsize * ratio,
                "axes.titlesize": self.title_fontsize * ratio,
                "legend.fontsize": self.legend_fontsize * ratio,
                "figure.titlesize": self.title_fontsize * ratio,
                "xtick.labelsize": self.xtics_fontsize * ratio,
                "ytick.labelsize": self.ytics_fontsize * ratio,
            }
        )

        if self.title is not None:
            ax.set_title(self.title, fontsize=self.title_fontsize * ratio)

        if self.xlabel is not None:
            ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * ratio)

        if self.ylabel is not None:
            ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * ratio)

        if self.grid:
            ax.grid(True, which="both", linestyle="--", alpha=0.5)

        if self.xscale is not None:
            ax.set_xscale(self.xscale)

        if self.yscale is not None:
            ax.set_yscale(self.yscale)

        if self.legend:
            # Only build a legend when labelled artists exist; otherwise
            # matplotlib warns and draws an empty legend box.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(loc=self.legend_loc, fontsize=self.legend_fontsize * ratio)

        if self.cmap is not None:
            plt.rc("image", cmap=self.cmap)

        # A None limit is treated like (None, None): leave the axis untouched.
        ylim = self.ylim if self.ylim is not None else (None, None)
        xlim = self.xlim if self.xlim is not None else (None, None)

        if ylim[0] is not None:
            ax.set_ylim(bottom=ylim[0])

        if ylim[1] is not None:
            ax.set_ylim(top=ylim[1])

        if xlim[0] is not None:
            ax.set_xlim(left=xlim[0])

        if xlim[1] is not None:
            ax.set_xlim(right=xlim[1])

        if self.xticklabels_rotation > 0:
            # Rotate the existing tick labels in place; re-setting them with
            # set_xticklabels would freeze the tick formatter.
            plt.setp(
                ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right"
            )

        return fig, ax


def _scale_fontsize(size, ratio):
    """Return the rcParams font size `size` multiplied by `ratio`.

    Named sizes ("large", "medium", ...) are relative to ``font.size``, which
    is scaled on its own, so they are returned unchanged. Multiplying the
    string would fail for a float ratio or give an invalid value for an int one.
    """
    if isinstance(size, str):
        return size
    return size * ratio


class fig_properties:
    """Class which handles the properties of a matplotlib ``Figure`` object
    (see matplotlib documentation). All attributes can be defined by the user and the
    object fig_properties can be passed to any smashbox plot function.
    """

    def __init__(
        self,
        figname=None,
        xsize=8,
        ysize=6,
        transparent=False,
        dpi=160,
        font_ratio=1,
        bbox_inches="tight",
    ):

        self.figname = figname
        """Path to the figure name to be saved"""
        self.xsize = xsize
        """Width of the figure in inch"""
        self.ysize = ysize
        """Height of the figure in inch"""
        self.transparent = transparent
        """Use transparency when exporting the figure, default is False"""
        self.dpi = dpi
        """Resolution (dpi), int, default is 160"""
        self.font_ratio = font_ratio
        """Global font ratio"""
        self.bbox_inches = bbox_inches
        """Constraint of the boundingbox of each ax (see matplotlib documentation)"""

    def update(self, **kwargs):
        """Set (or add) the attributes given as keyword arguments."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def change(self, figure):
        """Apply the properties to the current figure `figure`.

        Resize the figure and scale the global rcParams font sizes by
        `font_ratio`. If `figname` is set, save the figure there, creating the
        parent directory when needed.

        Parameter:
        ----------
        figure : list or tuple
            list of (fig, ax), figure and ax of a matplotlib subplot.
        Return:
        -------
        a tuple of the modified (fig,ax)
        """
        fig, ax = figure

        fig.set_figheight(self.ysize)
        fig.set_figwidth(self.xsize)

        # NOTE(review): these multiply the current global rcParams, so calling
        # change() twice without an ax_properties.change() reset in between
        # applies the ratio twice.
        for name in (
            "font.size",  # default text sizes
            "axes.titlesize",  # axes title
            "axes.labelsize",  # x and y labels
            "xtick.labelsize",  # tick labels
            "ytick.labelsize",  # tick labels
            "legend.fontsize",  # legend
            "figure.titlesize",  # figure title
        ):
            plt.rcParams[name] = _scale_fontsize(plt.rcParams[name], self.font_ratio)

        if self.figname is not None:
            head_path = os.path.dirname(self.figname)
            if head_path:
                os.makedirs(head_path, exist_ok=True)

            fig.savefig(
                self.figname,
                transparent=self.transparent,
                dpi=self.dpi,
                bbox_inches=self.bbox_inches,
            )

        return fig, ax


@tools.autocast_args
def save_figure(
    fig=None, figname="myfigure", xsize=8, ysize=6, transparent=False, dpi=80
):
    """
    Save a figure.
    Parameters:
    -----------
    fig: fig object returned by matplotlib.subplot
        the figure to save
    figname: str
        Path to the figure
    xsize: int
        width of the figure in inch
    ysize: int
        height of the figure in inch
    transparent: bool, default is False
        use transparency
    dpi : int
        resolution of the figure, default is 80
    Raises:
    -------
    ValueError if `fig` is None.
    """
    if fig is None:
        raise ValueError("fig is None: a matplotlib figure object is required.")
    fig.set_size_inches(xsize, ysize, forward=True)
    fig.savefig(figname, transparent=transparent, dpi=dpi, bbox_inches="tight")


def generate_palette(base_color, n, variation="hue"):
    """
    Generate a palette of colors from a base color.
    Parameter:
    ---------
    base_color: str
        matplotlib color string
    n: int
        number of color to generate
    variation: 'hue' | 'brightness'
        how to generate the color palette, by changing the hue or the brightness of
        the base color. Any other value returns `n` copies of the base color.
    Return: a list of `n` RGB tuples (components in 0-1)
    """
    # Convert the base color to normalised RGB (0-1), then to HSV.
    rgb = mcolors.to_rgb(base_color)
    h, s, v = colorsys.rgb_to_hsv(*rgb)

    # Generate n colors by changing the hue or the value (brightness).
    palette = []
    for i in range(n):
        new_h = h
        if variation == "hue":
            new_h = (h + i / n) % 1.0  # walk around the color wheel
        new_v = v
        if variation == "brightness":
            # clamp to [0.1, 1] to avoid pure black
            new_v = max(0.1, min(1.0, v * (0.5 + i / (2 * n))))

        palette.append(colorsys.hsv_to_rgb(new_h, s, new_v))

    return palette


@tools.autocast_args
def plot_chro(
    data: np.ndarray = np.zeros(shape=(1, 10)),
    t_axis: int = 1,
    outlets_name: list | tuple = [],
    columns: list | tuple = [],
    dt: float = 0.0,
    xtics: list | tuple = [],
    date_range: list = None,
    figure=None,
    ax_settings: dict | ax_properties = ax_properties(),
    fig_settings: dict | fig_properties = fig_properties(),
    plot_settings: dict | plot_properties = plot_properties(),
):
    """
    Plot a temporal chronicle of values
    Parameters:
    -----------
    data: np.ndarray of dimension 2.
        data to plot as a matrix of 2 dimension.
    t_axis : int
        the axis of the time in data, default is 1
    outlets_name: list
        the list of the outlets name, indexed like the columns of `data`. When a
        column has no name, the label is built from plot_settings.label and the
        column index.
    columns : list
        the column to be plotted in t_axis direction
    dt: float
        the timestep
    xtics : list
        list of date for the xtics. The format must be automatically read by numpy.Datetime
    date_range: list
        list of [date_start, date_end, timedelta] to generate the xtics
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties
    Return:
    -------
    a tuple (fig, ax)
    """

    # Work on copies so the caller's settings objects are never modified.
    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        plot_settings = plot_properties(**plot_settings)
    else:
        plot_settings = plot_properties(**plot_settings.__dict__)

    # Put the time axis first: data[t, k] is the value of series k at step t.
    data = np.moveaxis(data, t_axis, 0)

    if figure is not None:
        fig, ax = figure
    else:
        fig, ax = plt.subplots()

    fig, ax = ax_settings.change(figure=(fig, ax))

    if len(xtics) == 0:
        xtics = np.arange(0, data.shape[0])
        if dt > 0:
            xtics = xtics * dt
    else:
        # Build a new array: the caller's list must not be modified in place
        # (and a tuple cannot be).
        xtics = np.array([np.datetime64(tic) for tic in xtics])

    if date_range is not None:
        if len(date_range) != 3:
            raise ValueError(
                "date_range must have a length of 3: [date_start, date_end, step (s)]"
            )
        step = int(date_range[2])
        # The end bound is exclusive in np.arange, hence the extra step.
        xtics = np.arange(
            np.datetime64(date_range[0]),
            np.datetime64(date_range[1] + pd.Timedelta(seconds=step)),
            np.timedelta64(step, "s"),
        )

    if len(columns) > 0:
        args = plot_settings.__dict__.copy()
        del args["label"]
        del args["color"]

        palette = generate_palette(plot_settings.color, len(columns))

        # The palette is indexed by position in `columns`, the data and the
        # names by column index.
        for k, col in enumerate(columns):
            if col < len(outlets_name):
                label = outlets_name[col]
            else:
                label = f"{plot_settings.label} {col}".strip()
            ax.plot(xtics, data[:, col], **args, label=label, color=palette[k])
    else:
        ax.plot(xtics, data[:, 0], **plot_settings.__dict__)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax


def plot_hydrograph(
    model: Model | None = None,
    columns: list | tuple = [],
    outlets_name: list | tuple = [],
    plot_rainfall: bool = True,
    plot_qobs: bool = True,
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings_sim: dict | plot_properties = {},
    plot_settings_obs: dict | plot_properties = {},
):
    """
    Plot an hydrograph from a smash model
    Parameters:
    -----------
    model: a smash model object
        a smash model object
    columns : list
        the column to be plotted in t_axis direction
    outlets_name: list
        the list of the outlets name
    plot_rainfall: bool
        also draw the average rainfall of the first plotted column as inverted
        bars in a panel above the discharges, default is True
    plot_qobs: bool
        also draw the observed discharges (negative values are masked),
        default is True
    figure: tuple
        input figure to add new curves: (fig, ax) or, when plot_rainfall is
        True, (fig, ax_discharge, ax_rainfall)
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings_sim: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the simulated curve
    plot_settings_obs: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the observed curve
    Return:
    -------
    a tuple (fig, axes) where axes is (ax_discharge,) or (ax_discharge, ax_rainfall)
    """
    if model is None:
        raise ValueError("Input smash model object is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xlabel="Time", ylabel="discharges m^3/s", xtics_fontsize=10, ytics_fontsize=10
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings_sim, dict):
        default_plot_settings_sim = plot_properties(
            ls="-",
            lw=2.0,
            marker="",
            markersize=4,
            color="blue",
            label="Sim",
        )
        default_plot_settings_sim.update(**plot_settings_sim)
    else:
        default_plot_settings_sim = plot_properties(**plot_settings_sim.__dict__)

    # default color for multi curves: same color but different line type
    color = "blue" if len(columns) >= 2 else "black"

    if isinstance(plot_settings_obs, dict):
        default_plot_settings_obs = plot_properties(
            ls="--",
            lw=1.5,
            marker="",
            markersize=4,
            color=color,
            label="Obs",
        )
        default_plot_settings_obs.update(**plot_settings_obs)
    else:
        # Honour the object given by the caller (it used to be replaced by
        # plot_properties() defaults).
        default_plot_settings_obs = plot_properties(**plot_settings_obs.__dict__)

    # The first plotted value is at start_time + dt.
    dt = int(model.setup.dt)
    date_deb = datetime.datetime.fromisoformat(
        model.setup.start_time
    ) + datetime.timedelta(seconds=dt)
    date_end = datetime.datetime.fromisoformat(model.setup.end_time)
    date_range = [date_deb, date_end, model.setup.dt]

    if figure is None:
        if plot_rainfall:
            # Rainfall panel on top (ax2), discharges below (ax1).
            fig, (ax2, ax1) = plt.subplots(2, 1, height_ratios=[1, 4])
            fig.subplots_adjust(hspace=0)
        else:
            fig, ax1 = plt.subplots()
            ax2 = None
    else:
        if plot_rainfall and len(figure) < 3:
            raise ValueError(
                "figure must be (fig, ax_discharge, ax_rainfall) when plot_rainfall is True."
            )
        fig = figure[0]
        ax1 = figure[1]
        ax2 = figure[2] if plot_rainfall else None

    default_ax_settings.change(figure=(fig, ax1))

    if plot_qobs:
        fig, ax1 = plot_chro(
            np.where(model.response_data.q < 0, np.nan, model.response_data.q),
            date_range=date_range,
            columns=columns,
            outlets_name=["obs_" + name for name in outlets_name],
            figure=(fig, ax1),
            ax_settings=default_ax_settings,
            fig_settings=fig_settings,
            plot_settings=default_plot_settings_obs,
        )

    fig, ax1 = plot_chro(
        model.response.q,
        date_range=date_range,
        columns=columns,
        outlets_name=["sim_" + name for name in outlets_name],
        figure=(fig, ax1),
        ax_settings=default_ax_settings,
        fig_settings=fig_settings,
        plot_settings=default_plot_settings_sim,
    )

    step = np.timedelta64(dt, "s")
    xtics = np.arange(
        np.datetime64(date_range[0]),
        np.datetime64(date_range[1] + datetime.timedelta(seconds=dt)),
        step,
    )

    axes_list = [ax1]

    if plot_rainfall:
        col = columns[0] if len(columns) > 0 else 0
        prcp = model.atmos_data.mean_prcp[col, :]

        ax2.bar(
            xtics,
            prcp,
            label="Average rainfall (mm)",
            width=step,
            color="blue",
        )

        ax2.invert_yaxis()
        ax2.grid(alpha=0.7, ls="--")
        ax2.get_xaxis().set_visible(False)
        # Scale the inverted axis on the plotted series (it used to always
        # use the first column).
        ax2.set_ylim(bottom=1.2 * np.nanmax(prcp), top=0.0)
        ax2.set_ylabel("Average rainfall (mm)")

        axes_list.append(ax2)

    fig, ax = fig_settings.change(figure=(fig, tuple(axes_list)))

    return fig, ax


def plot_catchment_surface_error(
    mesh: dict = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot, for each gauge, the relative error (in %) between the catchment
    surface delineated on the mesh and the observed surface:
    (area_dln - area) / area * 100.
    Parameters:
    -----------
    mesh: dict
        The mesh of the Smash model. The keys "code", "area" and "area_dln" are read.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    Return:
    -------
    a tuple (fig, ax), or (None, None) when the mesh has no gauge.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Catchment surface error (Ssim-Sobs)/Sobs *100",
            ylabel="Surface error %",
            xlabel="Outlets",
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    # Reset before creating the figure so it does not inherit rcParams left
    # by a previous plot.
    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    surface_error = (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100

    bar_container = ax.bar(
        mesh["code"], surface_error, color="grey", tick_label=mesh["code"]
    )

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize * default_ax_settings.font_ratio,
    )

    # Applied again now that the bars (and their tick labels) exist.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax


def plot_catchment_surface_consistency(
    mesh: dict = None,
    label: bool = True,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the modeled surface vs the observed surface (both converted from m² to
    km² by dividing by 1000²), with the 1:1 line.
    Parameters:
    -----------
    mesh: dict, optional
        The mesh of the Smash model, defaults to None. The keys "code", "area"
        and "area_dln" are read.
    label: bool
        annotate each point with its gauge code, default is True
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the markers
    Return:
    -------
    a tuple (fig, ax), or (None, None) when the mesh has no gauge.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Modeled and observed surface consistency",
            ylabel="Modeled surface",
            xlabel="Observed surface",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="+",
            markersize=12,
            color="blue",
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    surface_model = mesh["area_dln"] / 1000.0**2.0
    surface_obs = mesh["area"] / 1000.0**2.0

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    ax.plot(
        surface_obs,
        surface_model,
        markersize=default_plot_settings.markersize,
        marker=default_plot_settings.marker,
        color=default_plot_settings.color,
        linestyle="None",
    )

    # 1:1 line over the observed range: perfectly consistent gauges lie on it.
    diagonal = np.linspace(min(surface_obs), max(surface_obs), 10)
    ax.plot(diagonal, diagonal, linewidth=2, color="grey")

    if label:
        # Alternate the alignment to limit overlapping annotations.
        alignments = ("left", "right")
        for i, code in enumerate(mesh["code"]):
            point = (surface_obs[i], surface_model[i])
            ax.annotate(
                code,
                point,
                textcoords="data",
                xytext=point,
                ha=alignments[i % 2],
                color="red",
                fontsize=default_ax_settings.annotate_fontsize,
            )

    ax.set(xlabel=default_ax_settings.xlabel, ylabel=default_ax_settings.ylabel)

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax


def plot_mesh(
    mesh: dict = None,
    coef_hydro: float = 99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the mesh of a smash model
    Parameters:
    -----------
    mesh: a smash mesh as dictionary
        the mesh of a smash model
    coef_hydro: float
        percentage used to filter the hydrographic network: cells whose cumulated
        surface is lower than (1 - coef_hydro/100) * max cumulated surface are not
        drawn. Default is 99, i.e. cells draining less than 1% of the largest
        drained surface are hidden.
    catchment_polygon: DataFrame, optional
        polygons (e.g. catchment boundaries) drawn on top, must provide a
        `.plot(ax=...)` method
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    Return:
    -------
    a tuple (fig, ax)
    """

    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Mesh of the Smash model",
            xlabel="x_coords",
            ylabel="y_coords",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    flow_acc = mesh["flwacc"]
    na = mesh["active_cell"] == 0

    # np.asarray works for plain and masked arrays alike (`.data` is a raw
    # memoryview on a plain ndarray). Division by 1e6 converts m² to km².
    flow_accum_bv = np.where(na, 0.0, np.asarray(flow_acc) / 1000000.0)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    flow_plot = np.where(flow_accum_bv < surfmin, np.nan, flow_accum_bv)
    flow_plot = np.where(na, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Active cells as a light gray background.
    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    # Drop the palest 30% of "Blues" so small drainage areas remain visible.
    myblues = matplotlib.colormaps["Blues"]
    cmp = ListedColormap(myblues(np.linspace(0.30, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    fig.colorbar(
        im, cmap="Blues", ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax
    )

    for i, code in enumerate(stations):
        # Alternate the alignment to limit overlapping gauge labels.
        ha = "left" if i % 2 == 0 else "right"

        row, col = gauge[i][0], gauge[i][1]
        # NOTE(review): dx is used for both offsets, which assumes square cells.
        half_cell = mesh["dx"][row, col] / 2
        coord = geo_toolbox.rowcol_to_xy(
            row,
            col,
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        ) + np.array([half_cell, -half_cell])

        ax.plot(coord[0], coord[1], color="green", marker="o", markersize=6)
        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=ha,
            color="red",
            fontsize=10,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax


def plot_xy_quantile(
    res_quantile,
    X,
    Y,
    res_quantile_obs=None,
    gauge_pos=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the discharges quantiles fitting at X,Y coordinates.
    Parameters:
    -----------
    res_quantile: dict
        The result of the discharge quantile computation.
    X: int
        Coordinates of the pixel in the row directions (X means row)
    Y: int
        Coordinates of the pixel in the column directions (Y means column)
    res_quantile_obs: dict
        The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs()
    gauge_pos: int
        gauge_pos is the index of gauge in the Smash mesh for which the quantile_discharge are provided to the function.
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties (markersize and lw are used)
    Return:
    -------
    a tuple (fig, ax)
    """
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xscale="log",
            xlabel=f"Return period (*{res_quantile['chunk_size']} days)",
            ylabel="Discharges (m³/s)",
            grid=True,
            legend=True,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(markersize=10)
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    quantile = res_quantile["Q_th"][X, Y]
    maxima = res_quantile["maxima"][X, Y]
    T_emp = res_quantile["T_emp"]
    loc = res_quantile["fit_loc"][X, Y]
    scale = res_quantile["fit_scale"][X, Y]
    shape = res_quantile["fit_shape"][X, Y]
    fit = res_quantile["fit"]

    sorted_data = np.sort(maxima)
    markersize = default_plot_settings.markersize
    lw = default_plot_settings.lw

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    if res_quantile_obs is not None and len(res_quantile_obs.keys()) > 0:
        if gauge_pos is None:
            raise ValueError(
                "gauge_pos is None. gauge_pos argument must be an integer corresponding to the gauge index."
            )
        maxima_obs = res_quantile_obs["maxima"][gauge_pos, :]
        T_emp_obs = res_quantile_obs["Temp"][gauge_pos, :]

        ax.plot(
            T_emp_obs,
            maxima_obs,
            "o",
            label="Observed",
            color="black",
            markersize=markersize,
        )

    ax.plot(T_emp, sorted_data, "o", label="Empirical", markersize=markersize)

    ax.plot(res_quantile["T"], quantile, "x", label="Theorical", markersize=markersize)

    Trange = np.linspace(1.1, np.max(res_quantile["T"]), 50)

    if fit == "gumbel":
        ax.plot(
            Trange,
            [stats.quantile_gumbel(T, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=lw,
        )
    elif fit == "gev":
        ax.plot(
            Trange,
            [stats.quantile_gev(T, shape, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=lw,
        )

    u_max = res_quantile.get("Umax")
    u_min = res_quantile.get("Umin")
    if u_max is not None and u_min is not None:
        # Only the linestyle in the format string: the color comes from the
        # keyword (a fmt color plus color= triggers a matplotlib warning).
        ax.plot(
            res_quantile["T"],
            u_max[X, Y],
            "--",
            label="Uncertainties (max)",
            color="grey",
            lw=lw,
        )
        ax.plot(
            res_quantile["T"],
            u_min[X, Y],
            "--",
            label="Uncertainties (min)",
            color="grey",
            lw=lw,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax


def plot_image(
    matrice=np.zeros(shape=(2, 2)),
    bbox=None,
    vmin=None,
    vmax=None,
    mask=None,
    extend=None,
    catchment_polygon=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Function for plotting a matrix as an image

    Parameters
    ----------
    matrice : numpy array
        Matrix to be plotted
    bbox : list
        ["left","right","bottom","top"] bouding box to put x and y coordinates instead
        of the shape of the matrix
    vmin: real,
        minimum z value
    vmax: real,
        maximum z value
    mask: integer, matrix, shape of matice, contain 0 for pixels that should not be plotted
    catchment_polygon: dataframe containing some polygon to be plotted.
        Ideally it must contain the boundaries of the catchment as a polygon from a shp file
        read by geopanda.
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class ax_properties
        object or dict with any attribute of class fig_settings

    Examples
    ----------
    smash.utils.plot_image(mesh_france['drained_area'],bbox=bbox,title="Surfaces
    drainées",xlabel="Longitude",ylabel="Latitude",zlabel="Surfaces drainées
    km^2",vmin=0.0,vmax=1000,mask=mesh_france['global_active_cell'])

    """

    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    matrice = np.float32(matrice)

    if bbox is not None:
        extent = [
            bbox["left"],
            bbox["right"],
            bbox["bottom"],
            bbox["top"],
        ]
    else:
        extent = None

    if mask is not None:
        matrice[np.where(mask == 0)] = np.nan

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    if vmax is None:
        vmax=np.max(matrice)
    if vmin is None:
        vmin=np.min(matrice)

    fig, ax = ax_settings.change(figure=(fig, ax))

    im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
1123 figure: tuple 1124 input figure as (fig,ax) to add a new curve 1125 ax_settings: dict or class ax_properties 1126 object or dict with any attribute of class ax_properties 1127 fig_settings: dict or class ax_properties 1128 object or dict with any attribute of class fig_settings 1129 1130 Examples 1131 ---------- 1132 smash.utils.plot_image(mesh_france['drained_area'],bbox=bbox,title="Surfaces 1133 drainées",xlabel="Longitude",ylabel="Latitude",zlabel="Surfaces drainées 1134 km^2",vmin=0.0,vmax=1000,mask=mesh_france['global_active_cell']) 1135 1136 """ 1137 1138 if isinstance(ax_settings, dict): 1139 ax_settings = ax_properties(**ax_settings) 1140 else: 1141 ax_settings = ax_properties(**ax_settings.__dict__) 1142 1143 if isinstance(fig_settings, dict): 1144 fig_settings = fig_properties(**fig_settings) 1145 else: 1146 fig_settings = fig_properties(**fig_settings.__dict__) 1147 1148 matrice = np.float32(matrice) 1149 1150 if bbox is not None: 1151 extent = [ 1152 bbox["left"], 1153 bbox["right"], 1154 bbox["bottom"], 1155 bbox["top"], 1156 ] 1157 else: 1158 extent = None 1159 1160 if mask is not None: 1161 matrice[np.where(mask == 0)] = np.nan 1162 1163 plt.rcParams.update(plt.rcParamsDefault) 1164 if figure is None: 1165 fig, ax = plt.subplots() 1166 else: 1167 fig, ax = figure 1168 1169 if vmax is None: 1170 vmax=np.max(matrice) 1171 if vmin is None: 1172 vmin=np.min(matrice) 1173 1174 fig, ax = ax_settings.change(figure=(fig, ax)) 1175 1176 im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap) 1177 1178 if catchment_polygon is not None: 1179 catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black") 1180 1181 # create an axes on the right side of ax. The width of cax will be 5% 1182 # of ax and the padding between cax and ax will be fixed at 0.05 inch. 
1183 divider = make_axes_locatable(ax) 1184 cax = divider.append_axes("right", size="5%", pad=0.05) 1185 1186 plt.colorbar(im, label=ax_settings.clabel, cax=cax) 1187 1188 fig, ax = ax_settings.change(figure=(fig, ax)) 1189 fig, ax = fig_settings.change(figure=(fig, ax)) 1190 1191 return (fig, ax) 1192 1193 1194def plot_misfit( 1195 values: np.ndarray = [], 1196 names: np.ndarray = [], 1197 columns: list | None = None, 1198 misfit: str = "nse", 1199 figure: list | tuple | None = None, 1200 ax_settings: dict | ax_properties = {}, 1201 fig_settings: dict | fig_properties = {}, 1202): 1203 """ 1204 Plot the misfit criteria between the simulated and observed discharges. 1205 Parameters: 1206 ----------- 1207 values: np.ndarray 1208 The result of the discharge misfit for all outlets. 1209 names: np.ndarray 1210 Outlets name or code stored in an np.ndarray. 1211 columns: list | None 1212 Columns of the np.ndarray to plot 1213 misfit: str 1214 Criteria to plot. choice are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge']" 1215 figure: tuple 1216 input figure as (fig,ax) to add a new curve 1217 ax_settings: dict or class ax_properties 1218 object or dict with any attribute of class ax_properties 1219 fig_settings: dict or class ax_properties 1220 object or dict with any attribute of class fig_settings 1221 """ 1222 1223 if isinstance(ax_settings, dict): 1224 default_ax_settings = ax_properties( 1225 ylabel=f"{misfit} criteria", 1226 xlabel="Gauges stations", 1227 grid=True, 1228 legend=True, 1229 xticklabels_rotation=45, 1230 xtics_fontsize=8, 1231 ) 1232 default_ax_settings.update(**ax_settings) 1233 else: 1234 default_ax_settings = ax_properties(**ax_settings.__dict__) 1235 1236 if isinstance(fig_settings, dict): 1237 fig_settings = fig_properties(**fig_settings) 1238 else: 1239 fig_settings = fig_properties(**fig_settings.__dict__) 1240 1241 if len(names) == 0: 1242 names = np.arange(len(values)) 1243 1244 if columns is not None: 1245 values = values[columns] 1246 names = 
names[columns] 1247 1248 # remove nan from plot 1249 columns = list(np.isnan(values) == False) 1250 # print(columns) 1251 if len(columns) > 0: 1252 values = values[columns] 1253 names = names[columns] 1254 1255 if figure is None: 1256 fig, ax = plt.subplots() 1257 else: 1258 fig, ax = figure 1259 1260 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1261 bar_container = ax.bar(names, values, color="grey", tick_label=names) 1262 1263 ax.bar_label( 1264 bar_container, 1265 fmt=lambda x: f"{x:.2f}", 1266 fontsize=default_ax_settings.barlabel_fontsize, 1267 ) 1268 1269 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1270 fig, ax = fig_settings.change(figure=(fig, ax)) 1271 1272 return fig, ax 1273 1274 1275def plot_outlet_stats( 1276 values_sim: np.ndarray | None = None, 1277 values_obs: np.ndarray | None = None, 1278 names: np.ndarray = [], 1279 columns: list | None = [], 1280 stat: str = "max", 1281 figure: list | tuple | None = None, 1282 ax_settings: dict | ax_properties = {}, 1283 fig_settings: dict | fig_properties = {}, 1284): 1285 """ 1286 Plot a statistical criteria at a given list of outlet. 1287 Parameters: 1288 ----------- 1289 values_sim: np.ndarray or None 1290 The result of the simulated stat for all outlets. 1291 values_obs: np.ndarray or None 1292 The result of the observed stat for all outlets. 1293 names: np.ndarray 1294 Outlets name or code stored in an np.ndarray. 1295 columns: list | None 1296 Columns of the np.ndarray to plot 1297 stat: str 1298 Criteria to plot. 
choice are ['max', 'min', 'mean', 'median', 'q20', 'q80']" 1299 figure: tuple 1300 input figure as (fig,ax) to add a new curve 1301 ax_settings: dict or class ax_properties 1302 object or dict with any attribute of class ax_properties 1303 fig_settings: dict or class ax_properties 1304 object or dict with any attribute of class fig_settings 1305 """ 1306 1307 if isinstance(ax_settings, dict): 1308 default_ax_settings = ax_properties( 1309 ylabel=f"{stat} criteria", 1310 xlabel="Gauges stations", 1311 grid=True, 1312 legend=True, 1313 xticklabels_rotation=45, 1314 xtics_fontsize=6, 1315 ) 1316 default_ax_settings.update(**ax_settings) 1317 else: 1318 default_ax_settings = ax_properties(**ax_settings.__dict__) 1319 1320 if isinstance(fig_settings, dict): 1321 fig_settings = fig_properties(**fig_settings) 1322 else: 1323 fig_settings = fig_properties(**fig_settings.__dict__) 1324 1325 if columns is not None: 1326 if values_sim is not None: 1327 values_sim = values_sim[columns] 1328 1329 if values_obs is not None: 1330 values_obs = values_obs[columns] 1331 1332 names = names[columns] 1333 1334 if np.all(values_obs == -99.0): 1335 values_obs = None 1336 1337 if values_sim is not None and values_obs is not None: 1338 if values_obs.size != values_sim.size: 1339 raise ValueError("values_sim and values_obs must have the same size !") 1340 1341 if figure is None: 1342 fig, ax = plt.subplots() 1343 else: 1344 fig, ax = figure 1345 1346 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1347 1348 x = np.arange(len(names)) 1349 width = 0.25 # the width of the bars 1350 1351 multiplier = 0 1352 1353 if values_sim is not None: 1354 offset = width * multiplier 1355 ax.bar(x + offset, values_sim, width, label="obs") 1356 multiplier += 1 1357 # ax.bar_label(rects, padding=3) 1358 1359 if values_obs is not None: 1360 offset = width * multiplier 1361 ax.bar(x + offset, values_obs, width, label="sim") 1362 # ax.bar_label(rects, padding=3) 1363 # multiplier += 1 1364 1365 
ax.set_xticks(x + width, names) 1366 1367 # bar_container = ax.bar(names, values, color="grey", tick_label=names) 1368 1369 # ax.bar_label( 1370 # bar_container, 1371 # fmt=lambda x: f"{x:.2f}", 1372 # fontsize=default_ax_settings.barlabel_fontsize, 1373 # ) 1374 1375 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1376 fig, ax = fig_settings.change(figure=(fig, ax)) 1377 1378 return fig, ax 1379 1380 1381def plot_misfit_map( 1382 values: np.ndarray = [], 1383 names: np.ndarray = [], 1384 mesh=None, 1385 misfit: str = "nse", 1386 coef_hydro=99.0, 1387 catchment_polygon: None | DataFrame = None, 1388 ax_settings: dict | ax_properties = {}, 1389 fig_settings: dict | fig_properties = {}, 1390 plot_settings: dict | plot_properties = {}, 1391): 1392 """ 1393 Map plot of the misfit criteria between the simulated and observed discharges. 1394 Parameters: 1395 ----------- 1396 values: np.ndarray 1397 The result of the discharge misfit for all outlets. 1398 names: np.ndarray 1399 Outlets name or code stored in an np.ndarray. 1400 mesh: None | dict 1401 The mesh of the Smash model as dict 1402 misfit: str 1403 Criteria to plot. choice are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge']" 1404 figure: tuple 1405 input figure as (fig,ax) to add a new curve 1406 ax_settings: dict or class ax_properties 1407 object or dict with any attribute of class ax_properties 1408 fig_settings: dict or class ax_properties 1409 object or dict with any attribute of class fig_settings 1410 plot_settings_sim: dict or class ax_properties 1411 object or dict with any attribute of class plot_settings. 
1412 """ 1413 if mesh is not None: 1414 if isinstance(mesh, dict): 1415 pass 1416 else: 1417 raise ValueError("mesh must be a dict") 1418 else: 1419 raise ValueError( 1420 "model or mesh are mandatory and must be a dict or a smash Model object" 1421 ) 1422 1423 if isinstance(ax_settings, dict): 1424 default_ax_settings = ax_properties( 1425 title=f"Map of {misfit} criteria over the domain.", 1426 xlabel="x_coords", 1427 ylabel="y_coords", 1428 cmap="turbo_r", 1429 ) 1430 default_ax_settings.update(**ax_settings) 1431 else: 1432 default_ax_settings = ax_properties(**ax_settings.__dict__) 1433 1434 if isinstance(fig_settings, dict): 1435 fig_settings = fig_properties(**fig_settings) 1436 else: 1437 fig_settings = fig_properties(**fig_settings.__dict__) 1438 1439 if isinstance(plot_settings, dict): 1440 default_plot_settings = plot_properties( 1441 marker="o", 1442 markersize=8, 1443 ) 1444 default_plot_settings.update(**plot_settings) 1445 else: 1446 default_plot_settings = plot_properties(**plot_settings.__dict__) 1447 1448 # unset attribute color, managed separatly 1449 delattr(default_plot_settings, "color") 1450 1451 gauge = mesh["gauge_pos"] 1452 stations = mesh["code"] 1453 flow_acc = mesh["flwacc"] 1454 na = mesh["active_cell"] == 0 1455 1456 bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh) 1457 extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"]) 1458 1459 flow_accum_bv = np.where(na, 0.0, flow_acc.data) 1460 surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv) 1461 mask_flow = flow_accum_bv < surfmin 1462 flow_plot = np.where(mask_flow, np.nan, flow_accum_bv.data) 1463 flow_plot = np.where(na, np.nan, flow_plot) 1464 1465 plt.rcParams.update(plt.rcParamsDefault) 1466 fig, ax = plt.subplots() 1467 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1468 1469 active_cell = np.where(na, np.nan, mesh["active_cell"]) 1470 cmap = ListedColormap(["lightgray"]) 1471 ax.imshow(active_cell, cmap=cmap, extent=extent) 1472 1473 myblues = 
matplotlib.colormaps["binary"] 1474 cmp = ListedColormap(myblues(np.linspace(0.20, 1.0, 265))) 1475 im = ax.imshow(flow_plot, cmap=cmp, extent=extent) 1476 1477 if catchment_polygon is not None: 1478 # catchment_polygon = gpd.read_file(outlets_shapefile) 1479 catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black") 1480 1481 # create an axes on the right side of ax. The width of cax will be 5% 1482 # of ax and the padding between cax and ax will be fixed at 0.05 inch. 1483 divider = make_axes_locatable(ax) 1484 cax = divider.append_axes("right", size="5%", pad=0.05) 1485 1486 fig.colorbar( 1487 im, cmap="Blues", ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax 1488 ) 1489 1490 # define bounds for the colormap 1491 if misfit == "nse" or misfit == "nnse": 1492 vmin = 0 1493 vmax = 1 1494 elif misfit == "rmse" or misfit == "nrmse" or misfit == "se": 1495 vmin = 0 1496 vmax = np.max(values) 1497 else: 1498 vmin = np.min(values) 1499 vmax = np.max(values) 1500 1501 colormap = cm.get_cmap(default_ax_settings.cmap) 1502 cmp = ListedColormap(colormap(np.linspace(vmin, vmax, 256))) 1503 1504 ha = "right" 1505 for i in range(len(stations)): 1506 1507 if ha == "right": 1508 ha = "left" 1509 str_val = str(np.round(values[i], 2)).rjust(int(len(stations[i]))) 1510 code = f"{stations[i]}\n {str_val}" 1511 1512 else: 1513 ha = "right" 1514 str_val = str(np.round(values[i], 2)).ljust(int(len(stations[i]))) 1515 code = f"{stations[i]}\n {str_val}" 1516 1517 coord = geo_toolbox.rowcol_to_xy( 1518 gauge[i][0], 1519 gauge[i][1], 1520 mesh["xmin"], 1521 mesh["ymax"], 1522 mesh["xres"], 1523 mesh["yres"], 1524 ) 1525 1526 ax.plot( 1527 coord[0], 1528 coord[1], 1529 color=cmp(values[i]), 1530 **default_plot_settings.__dict__, 1531 ) 1532 1533 ax.annotate( 1534 code, # this is the text 1535 # these are the coordinates to position the label 1536 (coord[0], coord[1]), 1537 textcoords="data", # how to position the text 1538 xytext=(coord[0], coord[1]), # distance from 
text to points (x,y) 1539 ha=ha, # horizontal alignment can be left, right or center 1540 color=cmp(values[i]), 1541 fontsize=default_ax_settings.annotate_fontsize 1542 * default_ax_settings.font_ratio, 1543 ) 1544 1545 import matplotlib as mpl 1546 1547 norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) 1548 # create an axes on the right side of ax. The width of cax will be 5% 1549 # of ax and the padding between cax and ax will be fixed at 0.05 inch. 1550 # divider = make_axes_locatable(ax) 1551 cax = divider.append_axes("right", size="5%", pad=0.5) 1552 1553 fig.colorbar( 1554 cm.ScalarMappable(norm=norm, cmap=cmp), 1555 cmap=cmp, 1556 ax=ax, 1557 cax=cax, 1558 label=misfit, 1559 shrink=0.75, 1560 location="right", 1561 ) 1562 1563 fig, ax = default_ax_settings.change(figure=(fig, ax)) 1564 1565 fig, ax = fig_settings.change(figure=(fig, ax)) 1566 1567 return fig, ax 1568 1569 1570# def _ax_settings( 1571# figure, 1572# title: str = None, 1573# xlabel: str = None, 1574# ylabel: str = None, 1575# clabel: str = None, 1576# font_ratio: int = 1, 1577# title_fontsize: int = 12, 1578# label_fontsize: int = 10, 1579# grid: bool = True, 1580# xscale: str | None = None, 1581# yscale: str | None = None, 1582# legend: bool = True, 1583# legend_loc: str = None, 1584# legend_fontsize: int = 8, 1585# xtics_fontsize: int = 8, 1586# ytics_fontsize: int = 8, 1587# cmap: str | None = None, 1588# xlim: tuple | list | None = (None, None), 1589# ylim: tuple | list | None = (None, None), 1590# ): 1591 1592# fig, ax = figure 1593 1594# plt.rcParams.update(plt.rcParamsDefault) 1595 1596# if title is not None: 1597# ax.set_title(title, fontsize=title_fontsize * font_ratio) 1598 1599# if xlabel is not None: 1600# ax.set_xlabel(xlabel) 1601 1602# if ylabel is not None: 1603# ax.set_ylabel(ylabel) 1604 1605# if grid: 1606# ax.grid(True, which="both", linestyle="--", alpha=0.5) 1607 1608# if xscale is not None: 1609# ax.set_xscale(xscale) 1610 1611# if yscale is not None: 1612# 
ax.set_xyscale(yscale) 1613 1614# if legend: 1615# ax.legend(loc=legend_loc) 1616 1617# if cmap is not None: 1618# plt.rc("image", cmap=cmap) 1619 1620# if ylim[0] != None: 1621# ax.set_ylim(bottom=ylim[0]) 1622 1623# if ylim[1] != None: 1624# ax.set_ylim(top=ylim[1]) 1625 1626# if xlim[0] != None: 1627# ax.set_xlim(left=xlim[0]) 1628 1629# if xlim[1] != None: 1630# ax.set_xlim(right=xlim[1]) 1631 1632# plt.rcParams.update( 1633# { 1634# "axes.labelsize": label_fontsize * font_ratio, 1635# "axes.titlesize": title_fontsize * font_ratio, 1636# "legend.fontsize": legend_fontsize * font_ratio, 1637# "figure.titlesize": title_fontsize * font_ratio, 1638# "xtick.labelsize": xtics_fontsize * font_ratio, 1639# "ytick.labelsize": ytics_fontsize * font_ratio, 1640# } 1641# ) 1642 1643# return fig, ax 1644 1645 1646# def _fig_settings( 1647# figure, 1648# figname=None, 1649# xsize=8, 1650# ysize=6, 1651# transparent=False, 1652# dpi=80, 1653# font_ratio=1, 1654# bbox_inches="tight", 1655# ): 1656 1657# fig, ax = figure 1658 1659# fig.set_figheight(ysize) 1660# fig.set_figwidth(xsize) 1661 1662# plt.rc( 1663# "font", size=plt.rcParams["font.size"] * font_ratio 1664# ) # controls default text sizes 1665# plt.rc( 1666# "axes", titlesize=plt.rcParams["axes.titlesize"] * font_ratio 1667# ) # fontsize of the axes title 1668# plt.rc( 1669# "axes", labelsize=plt.rcParams["axes.labelsize"] * font_ratio 1670# ) # fontsize of the x and y labels 1671# plt.rc( 1672# "xtick", labelsize=plt.rcParams["xtick.labelsize"] * font_ratio 1673# ) # fontsize of the tick labels 1674# plt.rc( 1675# "ytick", labelsize=plt.rcParams["ytick.labelsize"] * font_ratio 1676# ) # fontsize of the tick labels 1677# plt.rc( 1678# "legend", fontsize=plt.rcParams["legend.fontsize"] * font_ratio 1679# ) # legend fontsize 1680# plt.rc( 1681# "figure", titlesize=plt.rcParams["figure.titlesize"] * font_ratio 1682# ) # fontsize of the figure title 1683 1684# if figname is not None: 1685# fig.savefig(figname, 
transparent=transparent, dpi=dpi, bbox_inches=bbox_inches) 1686 1687# return fig, ax
class plot_properties:
    """Line and marker options forwarded to matplotlib's ``plot``.

    The user may set any attribute; the resulting object can then be handed
    to every smashbox plotting function.
    """

    def __init__(
        self,
        ls="-",
        lw=1.5,
        marker="",
        markersize=4,
        color="black",
        label="",
    ):
        self.ls = ls
        """Line style (matplotlib ``linestyle``)"""
        self.lw = lw
        """Line width (matplotlib ``linewidth``)"""
        self.marker = marker
        """Marker style (matplotlib ``marker``)"""
        self.markersize = markersize
        """Marker size (matplotlib ``markersize``)"""
        self.color = color
        """Line colour (matplotlib ``color``)"""
        self.label = label
        """Legend label of the line (matplotlib ``label``)"""

    def update(self, **kwargs):
        """Set every attribute named in ``kwargs`` to its associated value."""
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
Class which handles the different properties of the matplotlib plot function. All attributes can be defined by the user, and a plot_properties object can be passed to any smashbox plot function.
32 def __init__( 33 self, 34 ls="-", 35 lw=1.5, 36 marker="", 37 markersize=4, 38 color="black", 39 label="", 40 ): 41 self.ls = ls 42 """The style of the line (see matplotlib documentation)""" 43 self.lw = lw 44 """The linewidth of the line (see matplotlib documentation)""" 45 self.marker = marker 46 """The style of the marker (see matplotlib documentation)""" 47 self.markersize = markersize 48 """The size of the marker (see matplotlib documentation)""" 49 self.color = color 50 """The color of the line (see matplotlib documentation)""" 51 self.label = label 52 """The label of the line (see matplotlib documentation)"""
60class ax_properties: 61 """Class which handle differents properties of the matplotlib ax object 62 (see matplotlib documentation). All attributes can be defined by the user and the 63 object ax_properties can be passed to any smashbox plot function. 64 """ 65 66 def __init__( 67 self, 68 title: str = None, 69 xlabel: str = None, 70 ylabel: str = None, 71 clabel: str = None, 72 font_ratio: int = 1, 73 title_fontsize: int = 12, 74 label_fontsize: int = 10, 75 annotate_fontsize: int = 6, 76 grid: bool = True, 77 xscale: str | None = None, 78 yscale: str | None = None, 79 legend: bool = True, 80 legend_loc: str = None, 81 legend_fontsize: int = 8, 82 xtics_fontsize: int = 8, 83 ytics_fontsize: int = 8, 84 cmap: str | None = None, 85 xlim: tuple | list | None = (None, None), 86 ylim: tuple | list | None = (None, None), 87 xticklabels_rotation: int = 0, 88 barlabel_fontsize=6, 89 ): 90 91 self.title = title 92 """The title of the graphic""" 93 self.xlabel = xlabel 94 """The label of the x axis""" 95 self.ylabel = ylabel 96 """The label of the y axis""" 97 self.clabel = clabel 98 """The label of the colorbar""" 99 self.font_ratio = font_ratio 100 """Ratio of the global fontsize""" 101 self.title_fontsize = title_fontsize 102 """The fontsize of the title""" 103 self.label_fontsize = label_fontsize 104 """The label fontsize""" 105 self.annotate_fontsize = annotate_fontsize 106 """The annotation fontsize in plot""" 107 self.grid = grid 108 """Set the grid (boolean), default True""" 109 self.xscale = xscale 110 """Scale of the x axis (see matplotlib documentation)""" 111 self.yscale = yscale 112 """Scale of the x axis (see matplotlib documentation)""" 113 self.legend = legend 114 """Set the legend, boolean, default is True""" 115 self.legend_loc = legend_loc 116 """Localisation of the legend""" 117 self.legend_fontsize = legend_fontsize 118 """The fontsize of the legend""" 119 self.xtics_fontsize = xtics_fontsize 120 """The fontsize of the xtics""" 121 self.ytics_fontsize = 
ytics_fontsize 122 """The fontsize of the ytics""" 123 self.cmap = cmap 124 """The name of the used colormap""" 125 self.xlim = xlim 126 """The limit of the x axis, tuple or list, default is (None,None)""" 127 self.ylim = ylim 128 """The limit of the y axis, tuple or list, default is (None,None)""" 129 self.xticklabels_rotation = xticklabels_rotation 130 """Angle of the xtics labels, float, default is 0.""" 131 self.barlabel_fontsize = barlabel_fontsize * self.font_ratio 132 "Fontsize of the bar top label" 133 134 def update(self, **kwargs): 135 """Update the class attributes using kwarg (dictionnary)""" 136 for key, values in kwargs.items(): 137 setattr(self, key, values) 138 139 def change(self, figure): 140 """Apply change to the current figure `figure` 141 Parameter: 142 ---------- 143 figure : list or tuple 144 list of (fig, ax), figure and ax of a matplotlib subplot. 145 Return: 146 ------- 147 a tuple of the modified (fig,ax) 148 """ 149 fig, ax = figure 150 151 plt.rcParams.update(plt.rcParamsDefault) 152 153 plt.rcParams.update( 154 { 155 "axes.labelsize": self.label_fontsize * self.font_ratio, 156 "axes.titlesize": self.title_fontsize * self.font_ratio, 157 "legend.fontsize": self.legend_fontsize * self.font_ratio, 158 "figure.titlesize": self.title_fontsize * self.font_ratio, 159 "xtick.labelsize": self.xtics_fontsize * self.font_ratio, 160 "ytick.labelsize": self.ytics_fontsize * self.font_ratio, 161 } 162 ) 163 164 if self.title is not None: 165 166 ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio) 167 168 if self.xlabel is not None: 169 ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio) 170 171 if self.ylabel is not None: 172 ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio) 173 174 if self.grid: 175 ax.grid(True, which="both", linestyle="--", alpha=0.5) 176 177 if self.xscale is not None: 178 ax.set_xscale(self.xscale) 179 180 if self.yscale is not None: 181 
ax.set_xyscale(self.yscale) 182 183 if self.legend: 184 ax.legend( 185 loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio 186 ) 187 188 if self.cmap is not None: 189 plt.rc("image", cmap=self.cmap) 190 191 if self.ylim[0] is not None: 192 ax.set_ylim(bottom=self.ylim[0]) 193 194 if self.ylim[1] is not None: 195 ax.set_ylim(top=self.ylim[1]) 196 197 if self.xlim[0] is not None: 198 ax.set_xlim(left=self.xlim[0]) 199 200 if self.xlim[1] is not None: 201 ax.set_xlim(right=self.xlim[1]) 202 203 if self.xticklabels_rotation > 0: 204 ax.set_xticklabels( 205 ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right" 206 ) 207 208 return fig, ax
Class which handles the different properties of the matplotlib ax object (see matplotlib documentation). All attributes can be defined by the user, and an ax_properties object can be passed to any smashbox plot function.
66 def __init__( 67 self, 68 title: str = None, 69 xlabel: str = None, 70 ylabel: str = None, 71 clabel: str = None, 72 font_ratio: int = 1, 73 title_fontsize: int = 12, 74 label_fontsize: int = 10, 75 annotate_fontsize: int = 6, 76 grid: bool = True, 77 xscale: str | None = None, 78 yscale: str | None = None, 79 legend: bool = True, 80 legend_loc: str = None, 81 legend_fontsize: int = 8, 82 xtics_fontsize: int = 8, 83 ytics_fontsize: int = 8, 84 cmap: str | None = None, 85 xlim: tuple | list | None = (None, None), 86 ylim: tuple | list | None = (None, None), 87 xticklabels_rotation: int = 0, 88 barlabel_fontsize=6, 89 ): 90 91 self.title = title 92 """The title of the graphic""" 93 self.xlabel = xlabel 94 """The label of the x axis""" 95 self.ylabel = ylabel 96 """The label of the y axis""" 97 self.clabel = clabel 98 """The label of the colorbar""" 99 self.font_ratio = font_ratio 100 """Ratio of the global fontsize""" 101 self.title_fontsize = title_fontsize 102 """The fontsize of the title""" 103 self.label_fontsize = label_fontsize 104 """The label fontsize""" 105 self.annotate_fontsize = annotate_fontsize 106 """The annotation fontsize in plot""" 107 self.grid = grid 108 """Set the grid (boolean), default True""" 109 self.xscale = xscale 110 """Scale of the x axis (see matplotlib documentation)""" 111 self.yscale = yscale 112 """Scale of the x axis (see matplotlib documentation)""" 113 self.legend = legend 114 """Set the legend, boolean, default is True""" 115 self.legend_loc = legend_loc 116 """Localisation of the legend""" 117 self.legend_fontsize = legend_fontsize 118 """The fontsize of the legend""" 119 self.xtics_fontsize = xtics_fontsize 120 """The fontsize of the xtics""" 121 self.ytics_fontsize = ytics_fontsize 122 """The fontsize of the ytics""" 123 self.cmap = cmap 124 """The name of the used colormap""" 125 self.xlim = xlim 126 """The limit of the x axis, tuple or list, default is (None,None)""" 127 self.ylim = ylim 128 """The limit of the y axis, 
tuple or list, default is (None,None)""" 129 self.xticklabels_rotation = xticklabels_rotation 130 """Angle of the xtics labels, float, default is 0.""" 131 self.barlabel_fontsize = barlabel_fontsize * self.font_ratio 132 "Fontsize of the bar top label"
134 def update(self, **kwargs): 135 """Update the class attributes using kwarg (dictionnary)""" 136 for key, values in kwargs.items(): 137 setattr(self, key, values)
Update the class attributes from keyword arguments (a dictionary of attribute names and values).
139 def change(self, figure): 140 """Apply change to the current figure `figure` 141 Parameter: 142 ---------- 143 figure : list or tuple 144 list of (fig, ax), figure and ax of a matplotlib subplot. 145 Return: 146 ------- 147 a tuple of the modified (fig,ax) 148 """ 149 fig, ax = figure 150 151 plt.rcParams.update(plt.rcParamsDefault) 152 153 plt.rcParams.update( 154 { 155 "axes.labelsize": self.label_fontsize * self.font_ratio, 156 "axes.titlesize": self.title_fontsize * self.font_ratio, 157 "legend.fontsize": self.legend_fontsize * self.font_ratio, 158 "figure.titlesize": self.title_fontsize * self.font_ratio, 159 "xtick.labelsize": self.xtics_fontsize * self.font_ratio, 160 "ytick.labelsize": self.ytics_fontsize * self.font_ratio, 161 } 162 ) 163 164 if self.title is not None: 165 166 ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio) 167 168 if self.xlabel is not None: 169 ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio) 170 171 if self.ylabel is not None: 172 ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio) 173 174 if self.grid: 175 ax.grid(True, which="both", linestyle="--", alpha=0.5) 176 177 if self.xscale is not None: 178 ax.set_xscale(self.xscale) 179 180 if self.yscale is not None: 181 ax.set_xyscale(self.yscale) 182 183 if self.legend: 184 ax.legend( 185 loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio 186 ) 187 188 if self.cmap is not None: 189 plt.rc("image", cmap=self.cmap) 190 191 if self.ylim[0] is not None: 192 ax.set_ylim(bottom=self.ylim[0]) 193 194 if self.ylim[1] is not None: 195 ax.set_ylim(top=self.ylim[1]) 196 197 if self.xlim[0] is not None: 198 ax.set_xlim(left=self.xlim[0]) 199 200 if self.xlim[1] is not None: 201 ax.set_xlim(right=self.xlim[1]) 202 203 if self.xticklabels_rotation > 0: 204 ax.set_xticklabels( 205 ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right" 206 ) 207 208 return fig, ax
Apply the changes to the given figure `figure`.
Parameters:
figure : list or tuple — the (fig, ax) pair, i.e. the figure and ax of a matplotlib subplot.
Returns:
A tuple of the modified (fig, ax).
class fig_properties:
    """Class which handles the different properties of the matplotlib fig object
    (see matplotlib documentation). All attributes can be defined by the user and a
    fig_properties object can be passed to any smashbox plot function.
    """

    def __init__(
        self,
        figname=None,
        xsize=8,
        ysize=6,
        transparent=False,
        dpi=160,
        font_ratio=1,
        bbox_inches="tight",
    ):

        self.figname = figname
        """Path of the file the figure is saved to (not saved if None)"""
        self.xsize = xsize
        """Width of the figure in inch"""
        self.ysize = ysize
        """Height of the figure in inch"""
        self.transparent = transparent
        """Use transparency when exporting the figure, default is False"""
        self.dpi = dpi
        """Resolution (dpi) used when saving, int, default is 160"""
        self.font_ratio = font_ratio
        """Global font ratio"""
        self.bbox_inches = bbox_inches
        """Bounding box passed to savefig (see matplotlib savefig documentation)"""

    def update(self, **kwargs):
        """Update the class attributes from keyword arguments (name=value)."""

        for key, values in kwargs.items():
            setattr(self, key, values)

    def change(self, figure):
        """Apply the settings to the figure ``figure`` and optionally save it.

        The figure is resized to (xsize, ysize) inches. When font_ratio != 1,
        the font-size rcParams (font.size, axes.titlesize, axes.labelsize,
        xtick.labelsize, ytick.labelsize, legend.fontsize, figure.titlesize)
        are multiplied by font_ratio; the *current* values are scaled, so
        successive calls compound unless the rcParams are reset in between.
        If figname is set, the figure is saved there, creating the missing
        parent directories.

        Parameters
        ----------
        figure : list or tuple
            (fig, ax) pair, figure and ax of a matplotlib subplot.

        Returns
        -------
        tuple
            The (fig, ax) pair.
        """
        from matplotlib.font_manager import FontProperties

        fig, ax = figure

        fig.set_figheight(self.ysize)
        fig.set_figwidth(self.xsize)

        if self.font_ratio != 1:
            rc_font_keys = (
                "font.size",  # default text size
                "axes.titlesize",  # axes title
                "axes.labelsize",  # x and y labels
                "xtick.labelsize",  # tick labels
                "ytick.labelsize",  # tick labels
                "legend.fontsize",  # legend
                "figure.titlesize",  # figure title
            )

            def _in_points(value):
                # rcParams may hold named sizes such as "large", which cannot
                # be multiplied: convert them to points first.
                if isinstance(value, str):
                    return FontProperties(size=value).get_size_in_points()
                return value

            # Read every value before writing any, so that named sizes are
            # resolved against the unscaled font.size.
            scaled = {
                key: _in_points(plt.rcParams[key]) * self.font_ratio
                for key in rc_font_keys
            }
            plt.rcParams.update(scaled)

        if self.figname is not None:
            head_path = os.path.dirname(self.figname)

            # exist_ok avoids a check-then-create race on the directory.
            if head_path:
                os.makedirs(head_path, exist_ok=True)

            fig.savefig(
                self.figname,
                transparent=self.transparent,
                dpi=self.dpi,
                bbox_inches=self.bbox_inches,
            )

        return fig, ax
Class which handles the different properties of the matplotlib fig object (see the matplotlib documentation). All attributes can be defined by the user, and the fig_properties object can be passed to any smashbox plot function.
def __init__(
    self,
    figname=None,
    xsize=8,
    ysize=6,
    transparent=False,
    dpi=160,
    font_ratio=1,
    bbox_inches="tight",
):
    """Store the figure settings; each argument becomes an attribute of the same name."""

    self.figname = figname
    """Path of the file the figure is saved to; None means the figure is not saved"""
    self.xsize = xsize
    """Width of the figure in inches"""
    self.ysize = ysize
    """Height of the figure in inches"""
    self.transparent = transparent
    """Use transparency when exporting the figure, default is False"""
    self.dpi = dpi
    """Resolution (dpi) used when saving the figure, int, default is 160"""
    self.font_ratio = font_ratio
    """Global font ratio applied to the matplotlib font sizes"""
    self.bbox_inches = bbox_inches
    """Bounding box constraint passed to fig.savefig (see matplotlib documentation)"""
def update(self, **kwargs):
    """Copy every keyword argument onto this object as an attribute.

    Existing attributes are overwritten and unknown names are added, in the
    order the keywords were given.
    """
    for attr_name in kwargs:
        setattr(self, attr_name, kwargs[attr_name])
Update the class attributes using keyword arguments (dictionary).
def change(self, figure):
    """Apply the figure settings to `figure` and save it if `figname` is set.

    Parameter:
    ----------
    figure : list or tuple
        (fig, ax), figure and ax of a matplotlib subplot.

    Return:
    -------
    a tuple of the modified (fig, ax)
    """

    def _scaled(size):
        # Relative sizes ("large", "medium", ...) are strings that matplotlib
        # resolves against font.size, which is scaled below, so they already
        # follow font_ratio; multiplying them would give an invalid string
        # ("largelarge") or a TypeError for a float ratio.
        if isinstance(size, str):
            return size
        return size * self.font_ratio

    fig, ax = figure

    fig.set_figheight(self.ysize)
    fig.set_figwidth(self.xsize)

    # NOTE: sizes are scaled from the *current* rcParams (global state), so calling
    # change() repeatedly without resetting rcParams compounds the ratio.
    for group, key in (
        ("font", "size"),  # default text size
        ("axes", "titlesize"),  # axes title
        ("axes", "labelsize"),  # x and y labels
        ("xtick", "labelsize"),  # x tick labels
        ("ytick", "labelsize"),  # y tick labels
        ("legend", "fontsize"),  # legend
        ("figure", "titlesize"),  # figure title
    ):
        plt.rc(group, **{key: _scaled(plt.rcParams[f"{group}.{key}"])})

    if self.figname is not None:

        head_path = os.path.dirname(self.figname)
        if head_path:
            # exist_ok avoids a failure if the directory is created concurrently
            os.makedirs(head_path, exist_ok=True)

        fig.savefig(
            self.figname,
            transparent=self.transparent,
            dpi=self.dpi,
            bbox_inches=self.bbox_inches,
        )

    return fig, ax
Apply the changes to the current figure `figure`.
Parameters:
figure : list or tuple list of (fig, ax), figure and ax of a matplotlib subplot.
Return:
a tuple of the modified (fig,ax)
def wrapper(*args, **kwargs):
    """Bind the call arguments, coerce them to their annotated types, then call `func`.

    Relies on names from the enclosing decorator scope that are not visible here:
    `sig` (presumably `inspect.signature(func)`), `annotations` (parameter name ->
    annotation), `get_args` (presumably `typing.get_args`) and `func` (the decorated
    function).

    For every annotated argument whose exact type is not one of the annotation's
    members:
      * optional unions (`X | None`): each non-None member is tried as a constructor
        and the first successful conversion is kept; otherwise TypeError is raised.
      * any other annotation: if `isinstance` fails, `annotation(value)` is tried;
        otherwise TypeError is raised.
    A warning is printed before every conversion attempt.
    """

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    for name, value in bound.arguments.items():
        if name not in annotations:
            continue

        target_type = annotations[name]
        args_ = get_args(target_type)

        # A bare `None` annotation means "must be None".
        if target_type is None and len(args_) == 0:
            args_ = (type(None),)
            target_type = type(None)

        if type(value) in args_:
            continue

        if len(args_) > 1 and type(None) in args_:

            converted = False
            for t in args_:
                if t is type(None) or value is None:
                    continue
                print(
                    f"</> Warning: Arg '{name}' of type {type(value)} is being"
                    f" converted to {t}"
                )
                try:
                    bound.arguments[name] = t(value)
                except Exception:
                    # Best effort: try the next member of the union.
                    continue
                converted = True
                break

            if not converted:
                raise TypeError(
                    f"</> Error: Arg '{name}' must be a type of "
                    f" {args_}, got {value}"
                    f" ({type(value).__name__})"
                )

        elif not isinstance(value, target_type):
            print(
                f"</> Warning: Arg '{name}' of type {type(value)} is being"
                f" converted to {target_type}"
            )
            try:
                bound.arguments[name] = target_type(value)
            except Exception as err:
                # Union annotations such as `list | tuple` have no `__name__`;
                # fall back to their repr so the intended TypeError is raised
                # instead of an AttributeError.
                type_name = getattr(target_type, "__name__", repr(target_type))
                raise TypeError(
                    f"</> Error: Arg '{name}' must be a type of "
                    f" {type_name}, got {value}"
                    f" ({type(value).__name__})"
                ) from err

    return func(*bound.args, **bound.kwargs)
Save a figure.
Parameters:
fig: fig object returned by matplotlib.pyplot.subplots the figure to save figname: str Path to the figure xsize: int width of the figure in inches ysize: int height of the figure in inches transparent: bool, default is False use transparency dpi : int resolution of the figure, default is 80
def generate_palette(base_color, n, variation="hue"):
    """
    Build a list of `n` colors derived from `base_color`.
    Parameters:
    -----------
    base_color: str
        any color specification understood by matplotlib (e.g. "blue", "#1f77b4")
    n: int
        number of colors to produce
    variation: 'hue' | 'brightness'
        'hue' shifts the hue by k/n of a full turn for the k-th color;
        'brightness' scales the HSV value by 0.5 + k/(2n), clipped to [0.1, 1.0].
        Any other value yields n colors equal to the base color.
    Return: a list of (r, g, b) tuples with components in [0, 1]
    """
    # Work in HSV space so hue and brightness can be varied independently.
    hue, saturation, value = colorsys.rgb_to_hsv(*mcolors.to_rgb(base_color))

    def _shade(k):
        if variation == "hue":
            # walk around the color wheel, wrapping back into [0, 1)
            h_k = (hue + k / n) % 1.0
        else:
            h_k = hue
        if variation == "brightness":
            # the 0.1 floor keeps the darkest shade from becoming pure black
            v_k = max(0.1, min(1.0, value * (0.5 + k / (2 * n))))
        else:
            v_k = value
        return colorsys.hsv_to_rgb(h_k, saturation, v_k)

    return [_shade(k) for k in range(n)]
Generate a palette of colors from a base color.
Parameters:
base_color: str matplotlib color string n: int number of colors to generate variation: 'hue' | 'brightness' how to generate the color palette, by changing the hue or the brightness of the base color. Return: a list of (r, g, b) colors
def wrapper(*args, **kwargs):
    """Bind the call arguments, coerce them to their annotated types, then call `func`.

    Relies on names from the enclosing decorator scope that are not visible here:
    `sig` (presumably `inspect.signature(func)`), `annotations` (parameter name ->
    annotation), `get_args` (presumably `typing.get_args`) and `func` (the decorated
    function).

    For every annotated argument whose exact type is not one of the annotation's
    members:
      * optional unions (`X | None`): each non-None member is tried as a constructor
        and the first successful conversion is kept; otherwise TypeError is raised.
      * any other annotation: if `isinstance` fails, `annotation(value)` is tried;
        otherwise TypeError is raised.
    A warning is printed before every conversion attempt.
    """

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    for name, value in bound.arguments.items():
        if name not in annotations:
            continue

        target_type = annotations[name]
        args_ = get_args(target_type)

        # A bare `None` annotation means "must be None".
        if target_type is None and len(args_) == 0:
            args_ = (type(None),)
            target_type = type(None)

        if type(value) in args_:
            continue

        if len(args_) > 1 and type(None) in args_:

            converted = False
            for t in args_:
                if t is type(None) or value is None:
                    continue
                print(
                    f"</> Warning: Arg '{name}' of type {type(value)} is being"
                    f" converted to {t}"
                )
                try:
                    bound.arguments[name] = t(value)
                except Exception:
                    # Best effort: try the next member of the union.
                    continue
                converted = True
                break

            if not converted:
                raise TypeError(
                    f"</> Error: Arg '{name}' must be a type of "
                    f" {args_}, got {value}"
                    f" ({type(value).__name__})"
                )

        elif not isinstance(value, target_type):
            print(
                f"</> Warning: Arg '{name}' of type {type(value)} is being"
                f" converted to {target_type}"
            )
            try:
                bound.arguments[name] = target_type(value)
            except Exception as err:
                # Union annotations such as `list | tuple` have no `__name__`;
                # fall back to their repr so the intended TypeError is raised
                # instead of an AttributeError.
                type_name = getattr(target_type, "__name__", repr(target_type))
                raise TypeError(
                    f"</> Error: Arg '{name}' must be a type of "
                    f" {type_name}, got {value}"
                    f" ({type(value).__name__})"
                ) from err

    return func(*bound.args, **bound.kwargs)
Plot a time series (chronicle) of values
Parameters:
data: np.ndarray of dimension 2 the data to plot, as a 2-dimensional matrix. t_axis : int the time axis of data, default is 1 outlets_name: list the list of the outlet names columns : list the columns to be plotted along the t_axis direction dt: float the time step xtics : list list of dates for the xtics. The format must be readable by numpy.datetime64 date_range: list list of [date_start, date_end, timedelta] to generate the xtics figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties
def plot_hydrograph(
    model: Model | None = None,
    columns: list | tuple = [],
    outlets_name: list | tuple = [],
    plot_rainfall: bool = True,
    plot_qobs: bool = True,
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings_sim: dict | plot_properties = {},
    plot_settings_obs: dict | plot_properties = {},
):
    """
    Plot a hydrograph from a smash model
    Parameters:
    -----------
    model: a smash model object
        a smash model object
    columns : list
        the columns (outlets) to be plotted in t_axis direction
    outlets_name: list
        the list of the outlets name
    plot_rainfall: bool
        if True, the average rainfall of the first selected column (column 0 if
        `columns` is empty) is drawn as an inverted bar chart above the
        hydrograph, default is True
    plot_qobs: bool
        if True, the observed discharges are plotted too (negative values are
        hidden), default is True
    figure: tuple
        input figure to add new curves: (fig, ax) when plot_rainfall is False,
        (fig, ax, ax_rainfall) when plot_rainfall is True
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings_sim: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the simulated curve
    plot_settings_obs: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the observed curve
    Return:
    -------
    (fig, ax) where ax is a tuple holding the hydrograph axes and, when
    plot_rainfall is True, the rainfall axes
    Raises:
    -------
    ValueError if model is None
    """
    if model is None:
        raise ValueError("Input smash model object is None.")

    # Settings are always rebuilt into new objects; the caller's arguments are not modified.
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xlabel="Time", ylabel="discharges m^3/s", xtics_fontsize=10, ytics_fontsize=10
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings_sim, dict):
        default_plot_settings_sim = plot_properties(
            ls="-",
            lw=2.0,
            marker="",
            markersize=4,
            color="blue",
            label="Sim",
        )
        default_plot_settings_sim.update(**plot_settings_sim)
    else:
        default_plot_settings_sim = plot_properties(**plot_settings_sim.__dict__)

    # Default color of the observed curves: with several columns they share the
    # simulated color and differ only by their line style.
    obs_color = "blue" if len(columns) >= 2 else "black"

    if isinstance(plot_settings_obs, dict):
        default_plot_settings_obs = plot_properties(
            ls="--",
            lw=1.5,
            marker="",
            markersize=4,
            color=obs_color,
            label="Obs",
        )
        default_plot_settings_obs.update(**plot_settings_obs)
    else:
        # honor a user-supplied plot_properties object (copied, like the sim settings)
        default_plot_settings_obs = plot_properties(**plot_settings_obs.__dict__)

    # The time axis starts one time step after start_time.
    date_deb = datetime.datetime.fromisoformat(
        model.setup.start_time
    ) + datetime.timedelta(seconds=int(model.setup.dt))
    date_end = datetime.datetime.fromisoformat(model.setup.end_time)
    date_range = [date_deb, date_end, model.setup.dt]

    if figure is None:
        if plot_rainfall:
            # Rainfall on a thin top axes (ax2), hydrograph below (ax1), no gap.
            fig, (ax2, ax1) = plt.subplots(2, 1, height_ratios=[1, 4])
            fig.subplots_adjust(hspace=0)
        else:
            fig, ax1 = plt.subplots()
    elif plot_rainfall:
        fig, ax1, ax2 = figure[0], figure[1], figure[2]
    else:
        fig, ax1 = figure[0], figure[1]

    fig, _ = default_ax_settings.change(figure=(fig, ax1))

    if plot_qobs:
        fig, ax1 = plot_chro(
            # negative observed discharges are treated as missing values
            np.where(model.response_data.q < 0, np.nan, model.response_data.q),
            date_range=date_range,
            columns=columns,
            outlets_name=["obs_" + name for name in outlets_name],
            figure=(fig, ax1),
            ax_settings=default_ax_settings,
            fig_settings=fig_settings,
            plot_settings=default_plot_settings_obs,
        )

    fig, ax1 = plot_chro(
        model.response.q,
        date_range=date_range,
        columns=columns,
        outlets_name=["sim_" + name for name in outlets_name],
        figure=(fig, ax1),
        ax_settings=default_ax_settings,
        fig_settings=fig_settings,
        plot_settings=default_plot_settings_sim,
    )

    axes_list = [ax1]

    if plot_rainfall:

        time_step = np.timedelta64(int(date_range[2]), "s")
        xtics = np.arange(
            np.datetime64(date_range[0]),
            np.datetime64(date_range[1] + datetime.timedelta(seconds=int(date_range[2]))),
            time_step,
        )

        # Only one rainfall series is drawn: the first selected column (or column 0).
        col = columns[0] if len(columns) > 0 else 0
        rainfall = model.atmos_data.mean_prcp[col, :]

        ax2.bar(
            xtics,
            rainfall,
            label="Average rainfall (mm)",
            width=time_step,
            color="blue",
        )

        ax2.invert_yaxis()
        ax2.grid(alpha=0.7, ls="--")
        ax2.get_xaxis().set_visible(False)
        # Scale the rainfall axis to the plotted column, ignoring missing values.
        ax2.set_ylim(bottom=1.2 * np.nanmax(rainfall), top=0.0)
        ax2.set_ylabel("Average rainfall (mm)")

        axes_list.append(ax2)

    fig, ax = fig_settings.change(figure=(fig, tuple(axes_list)))

    return fig, ax
Plot a hydrograph from a smash model
Parameters:
model: a smash model object a smash model object outlets_name: list the list of the outlet names columns : list the columns to be plotted in t_axis direction plot_rainfall: bool if True, the average rainfall of the first selected column is plotted above the hydrograph plot_qobs: bool if True, the observed discharges are plotted too figure: tuple input figure as (fig,ax) to add a new curve, or (fig,ax,ax_rainfall) when plot_rainfall is True ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings_sim: dict or class plot_properties object or dict with any attribute of class plot_properties. Control the simulated curve plot_settings_obs: dict or class plot_properties object or dict with any attribute of class plot_properties. Control the observed curve
def plot_catchment_surface_error(
    mesh: dict = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot, for each gauge, the relative error between the modeled and the observed
    catchment surfaces: (area_dln - area) / area * 100, as a bar chart.
    Parameters:
    -----------
    mesh: dict
        The mesh of the Smash model. The keys "code", "area" (observed surface) and
        "area_dln" (modeled surface) are read.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    Return:
    -------
    (fig, ax), or (None, None) when the mesh has no gauge
    Raises:
    -------
    ValueError if mesh is None
    """

    if mesh is None:
        raise ValueError("Input mesh is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Catchment surface error (Ssim-Sobs)/Sobs *100",
            ylabel="Surface error %",
            xlabel="Outlets",
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    # Relative error (%) of the modeled surface with respect to the observed one.
    surface_error = (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bar_container = ax.bar(
        mesh["code"], surface_error, color="grey", tick_label=mesh["code"]
    )

    # A %-style format string works with every matplotlib version that has
    # bar_label (a callable fmt requires matplotlib >= 3.7).
    ax.bar_label(
        bar_container,
        fmt="%.2f",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    # Re-apply the ax settings, presumably so that the tick-label settings also
    # apply to the labels created by ax.bar.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot the relative error between the modeled and observed catchment surfaces, (Ssim - Sobs) / Sobs * 100, for each gauge.
Parameters:
mesh: dict The mesh of the Smash model; the keys "code", "area" (observed surface) and "area_dln" (modeled surface) are used. ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties
def plot_catchment_surface_consistency(
    mesh: dict = None,
    label: bool = True,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the modeled surface vs the observed surface
    Parameters:
    -----------
    mesh: dict, optional
        The mesh of the Smash model, defaults to None. The keys "code", "area"
        (observed surface) and "area_dln" (modeled surface) are read.
    label: bool
        if True, annotate each point with its gauge code, default is True
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the
        marker, markersize and color of the points
    Return:
    -------
    (fig, ax), or (None, None) when the mesh has no gauge
    Raises:
    -------
    ValueError if mesh is None
    """

    if mesh is None:
        raise ValueError("Input mesh is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Modeled and observed surface consistency",
            ylabel="Modeled surface",
            xlabel="Observed surface",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="+",
            markersize=12,
            color="blue",
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    # Surfaces divided by 1e6 (presumably m² -> km²).
    surface_model = mesh["area_dln"] / 1000.0**2.0
    surface_obs = mesh["area"] / 1000.0**2.0

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    ax.plot(
        surface_obs,
        surface_model,
        markersize=default_plot_settings.markersize,
        marker=default_plot_settings.marker,
        color=default_plot_settings.color,
        linestyle="None",
    )

    # 1:1 reference line spanning the observed surfaces (NaN-safe bounds).
    lower, upper = np.nanmin(surface_obs), np.nanmax(surface_obs)
    ax.plot(
        np.linspace(lower, upper, 10),
        np.linspace(lower, upper, 10),
        linewidth=2,
        color="grey",
    )

    if label:
        # alternate the horizontal alignment to reduce label overlap
        ha = ("left", "right")
        for i, code in enumerate(mesh["code"]):
            ax.annotate(
                code,
                (surface_obs[i], surface_model[i]),
                textcoords="data",
                xytext=(surface_obs[i], surface_model[i]),
                ha=ha[i % 2],
                color="red",
                fontsize=default_ax_settings.annotate_fontsize,
            )

    ax.set(xlabel=default_ax_settings.xlabel, ylabel=default_ax_settings.ylabel)

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot the modeled surface vs the observed surface
Parameters:
mesh: dict, optional The mesh of the Smash model, defaults to None label: bool if True, annotate each point with its gauge code, default is True ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties. Control the plotted points
def plot_mesh(
    mesh: dict = None,
    coef_hydro: float = 99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the mesh of a smash model
    Parameters:
    -----------
    mesh: dict
        a smash mesh as a dictionary
    coef_hydro: float
        threshold (in %) used to colorize the hydrographic network according to the
        drained surface: cells draining less than (100 - coef_hydro)% of the largest
        drained surface are hidden. Default is 99.
    catchment_polygon: DataFrame or None
        polygons drawn over the mesh (edges only), e.g. catchment boundaries read
        with geopandas
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    Return:
    -------
    (fig, ax)
    Raises:
    -------
    ValueError if mesh is None or not a dict
    """

    if mesh is None:
        raise ValueError(
            "model or mesh are mandatory and must be a dict or a smash Model object"
        )
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Mesh of the Smash model",
            xlabel="x_coords",
            ylabel="y_coords",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    inactive = mesh["active_cell"] == 0

    # Drained surface divided by 1e6 (presumably m² -> km²); inactive cells -> 0.
    flow_accum_bv = np.where(inactive, 0.0, mesh["flwacc"].data / 1000000.0)
    # Hide the cells draining less than (100 - coef_hydro)% of the largest surface.
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    flow_plot = np.where(flow_accum_bv < surfmin, np.nan, flow_accum_bv)
    flow_plot = np.where(inactive, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Background: active cells in light gray.
    active_cell = np.where(inactive, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    # Hydrographic network: upper 70% of the "Blues" colormap.
    myblues = matplotlib.colormaps["Blues"]
    cmp = ListedColormap(myblues(np.linspace(0.30, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # Colorbar in an axes appended to the right of ax (5% of its width, 0.05 in pad);
    # its colormap comes from `im`.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax, label="Cumulated surface (km²)")

    # Gauges: green marker + red code, alternating the label alignment.
    for i, code in enumerate(stations):
        row, col = gauge[i][0], gauge[i][1]
        half_cell = mesh["dx"][row, col] / 2

        # Offset by half a cell, presumably from the cell corner returned by
        # rowcol_to_xy to the cell center.
        coord = geo_toolbox.rowcol_to_xy(
            row,
            col,
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        ) + np.array([half_cell, -half_cell])

        ax.plot(coord[0], coord[1], color="green", marker="o", markersize=6)
        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=("left", "right")[i % 2],
            color="red",
            fontsize=10,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot the mesh of a smash model
Parameters:
mesh: dict a smash mesh as a dictionary coef_hydro: float the coefficient used to colorize the hydrographic network according to the cumulative surface: cells draining less than (100 - coef_hydro)% of the largest drained surface are hidden, default is 99. catchment_polygon: DataFrame polygons (e.g. catchment boundaries) drawn over the mesh ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties
def plot_xy_quantile(
    res_quantile,
    X,
    Y,
    res_quantile_obs=None,
    gauge_pos=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the fitted discharge quantiles at the X,Y coordinates.
    Parameters:
    -----------
    res_quantile: dict
        The result of the discharge quantile computation.
    X: int
        Coordinates of the pixel in the row directions (X means row)
    Y: int
        Coordinates of the pixel in the column directions (Y means column)
    res_quantile_obs: dict
        The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs()
    gauge_pos: int
        index of the gauge in the Smash mesh for which the observed quantiles are plotted. Mandatory when res_quantile_obs is not empty.
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties; markersize and
        lw are used
    Return:
    -------
    (fig, ax)
    Raises:
    -------
    ValueError if res_quantile_obs is given without gauge_pos
    """
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xscale="log",
            xlabel=f"Return period (*{res_quantile['chunk_size']} days)",
            ylabel="Discharges (m³/s)",
            grid=True,
            legend=True,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(markersize=10)
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    quantile = res_quantile["Q_th"][X, Y]
    maxima = res_quantile["maxima"][X, Y]
    T_emp = res_quantile["T_emp"]
    loc = res_quantile["fit_loc"][X, Y]
    scale = res_quantile["fit_scale"][X, Y]
    shape = res_quantile["fit_shape"][X, Y]
    fit = res_quantile["fit"]

    sorted_data = np.sort(maxima)

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    if res_quantile_obs is not None and len(res_quantile_obs.keys()) > 0:
        if gauge_pos is None:
            raise ValueError(
                "gauge_pos is None. gauge_pos argument must be an integer corresponding to the gauge index."
            )
        maxima_obs = res_quantile_obs["maxima"][gauge_pos, :]
        T_emp_obs = res_quantile_obs["Temp"][gauge_pos, :]

        ax.plot(
            T_emp_obs,
            maxima_obs,
            "o",
            label="Observed",
            color="black",
            markersize=default_plot_settings.markersize,
        )

    ax.plot(
        T_emp,
        sorted_data,
        "o",
        label="Empirical",
        markersize=default_plot_settings.markersize,
    )

    ax.plot(
        res_quantile["T"],
        quantile,
        "x",
        label="Theorical",
        markersize=default_plot_settings.markersize,
    )

    # Return periods used to draw the fitted distribution curve.
    Trange = np.linspace(1.1, np.max(res_quantile["T"]), 50)

    if fit == "gumbel":
        ax.plot(
            Trange,
            [stats.quantile_gumbel(T, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=default_plot_settings.lw,
        )
    elif fit == "gev":
        ax.plot(
            Trange,
            [stats.quantile_gev(T, shape, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=default_plot_settings.lw,
        )

    if res_quantile.get("Umax") is not None and res_quantile.get("Umin") is not None:
        # The color comes from the keyword only: a color in the fmt string as well
        # ("r--") triggers a matplotlib "redundantly defined" warning.
        ax.plot(
            res_quantile["T"],
            res_quantile["Umax"][X, Y],
            "--",
            label="Uncertainties (max)",
            color="grey",
            lw=default_plot_settings.lw,
        )
        ax.plot(
            res_quantile["T"],
            res_quantile["Umin"][X, Y],
            "--",
            label="Uncertainties (min)",
            color="grey",
            lw=default_plot_settings.lw,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot the fitted discharge quantiles at the X,Y coordinates.
Parameters:
res_quantile: dict The result of the discharge quantile computation. res_quantile_obs: dict The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs() gauge_pos: int gauge_pos is the index of the gauge in the Smash mesh for which the observed quantiles are plotted. X: int Coordinates of the pixel in the row directions (X means row) Y: int Coordinates of the pixel in the column directions (Y means column) figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties
def plot_image(
    matrice=np.zeros(shape=(2, 2)),
    bbox=None,
    vmin=None,
    vmax=None,
    mask=None,
    extend=None,
    catchment_polygon=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Function for plotting a matrix as an image

    Parameters
    ----------
    matrice : numpy array
        Matrix to be plotted. It is copied as float32 and never modified.
    bbox : dict
        bounding box with keys "left", "right", "bottom", "top", used to put x and y
        coordinates instead of the shape of the matrix
    vmin: real,
        minimum z value; defaults to the minimum of the (masked) matrix, NaN ignored
    vmax: real,
        maximum z value; defaults to the maximum of the (masked) matrix, NaN ignored
    mask: integer matrix with the shape of matrice; pixels where it is 0 are not plotted
    extend: str or None
        colorbar extension: 'neither', 'both', 'min' or 'max'; None keeps the
        matplotlib default
    catchment_polygon: dataframe containing some polygon to be plotted.
        Ideally it must contain the boundaries of the catchment as a polygon from a shp file
        read by geopandas.
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties (cmap and clabel
        are used for the image and the colorbar)
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Return
    ----------
    (fig, ax)

    Examples
    ----------
    plot_image(mesh_france['drained_area'], bbox=bbox, vmin=0.0, vmax=1000,
    mask=mesh_france['global_active_cell'], ax_settings={"title": "Surfaces drainées",
    "xlabel": "Longitude", "ylabel": "Latitude", "clabel": "Surfaces drainées km^2"})

    """

    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Always copy: np.float32(arr) returns the caller's array itself when it is
    # already float32, and the masking below would then modify the caller's data.
    matrice = np.array(matrice, dtype=np.float32)

    if bbox is not None:
        extent = [
            bbox["left"],
            bbox["right"],
            bbox["bottom"],
            bbox["top"],
        ]
    else:
        extent = None

    if mask is not None:
        matrice[np.asarray(mask) == 0] = np.nan

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    # NaN-aware bounds: masked pixels are NaN and would make np.max/np.min NaN.
    if vmax is None:
        vmax = np.nanmax(matrice)
    if vmin is None:
        vmin = np.nanmin(matrice)

    fig, ax = ax_settings.change(figure=(fig, ax))

    im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # Colorbar in an axes appended to the right of ax (5% of its width, 0.05 in pad).
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # fig.colorbar targets this figure even when it is not the current pyplot figure.
    fig.colorbar(im, label=ax_settings.clabel, cax=cax, extend=extend)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return (fig, ax)
Plot a matrix as an image.
Parameters
matrice : numpy array Matrix to be plotted bbox : dict {"left","right","bottom","top"} bounding box to put x and y coordinates instead of the shape of the matrix vmin: real, minimum z value vmax: real, maximum z value mask: integer matrix with the shape of matrice, containing 0 for pixels that should not be plotted catchment_polygon: dataframe containing some polygons to be plotted. Ideally it must contain the boundaries of the catchment as a polygon from a shp file read by geopandas. figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties
Examples
smash.utils.plot_image(mesh_france['drained_area'],bbox=bbox,vmin=0.0,vmax=1000,mask=mesh_france['global_active_cell'],ax_settings={"title":"Surfaces drainées","xlabel":"Longitude","ylabel":"Latitude","clabel":"Surfaces drainées km^2"})
def plot_misfit(
    values: np.ndarray = [],
    names: np.ndarray = [],
    columns: list | None = None,
    misfit: str = "nse",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the misfit criterion between the simulated and observed discharges
    as a bar chart (one bar per outlet, annotated with its value).

    Parameters
    ----------
    values: np.ndarray
        The result of the discharge misfit for all outlets. Outlets whose value
        is NaN are not plotted.
    names: np.ndarray
        Outlet names or codes (same order as `values`). When empty, the outlet
        indices are used instead.
    columns: list | None
        Indices of the outlets to plot. None plots all outlets.
    misfit: str
        Criterion name, used in the default y label. Choices are
        ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
    figure: tuple
        Input figure as (fig, ax) to draw into; a new one is created otherwise.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties. A dict is
        merged over this function's defaults; an object replaces them.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns
    -------
    tuple
        The (fig, ax) pair that was drawn into.
    """

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{misfit} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=8,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Normalise to arrays so that index / boolean-mask selection below also
    # works when plain Python lists are passed.
    values = np.asarray(values, dtype=float)
    names = np.asarray(names)

    if names.size == 0:
        names = np.arange(len(values))

    if columns is not None:
        values = values[columns]
        names = names[columns]

    # remove nan from plot
    valid = ~np.isnan(values)
    values = values[valid]
    names = names[valid]

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    bar_container = ax.bar(names, values, color="grey", tick_label=names)

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot the misfit criteria between the simulated and observed discharges.
Parameters:
values: np.ndarray The result of the discharge misfit for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray. columns: list | None Indices (columns) of the np.ndarray to plot misfit: str Criterion to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'] figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties
def plot_outlet_stats(
    values_sim: np.ndarray | None = None,
    values_obs: np.ndarray | None = None,
    names: np.ndarray = [],
    columns: list | None = [],
    stat: str = "max",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot a statistical criterion for a given list of outlets, as grouped bars
    (simulated next to observed).

    Parameters
    ----------
    values_sim: np.ndarray or None
        The result of the simulated stat for all outlets (bars labelled "sim").
    values_obs: np.ndarray or None
        The result of the observed stat for all outlets (bars labelled "obs").
        Ignored when every selected value equals -99.0.
    names: np.ndarray
        Outlet names or codes. When empty, the outlet indices are used instead.
    columns: list | None
        Indices of the outlets to plot. None or an empty list plots all outlets.
    stat: str
        Criterion name, used in the default y label. Choices are
        ['max', 'min', 'mean', 'median', 'q20', 'q80'].
    figure: tuple
        Input figure as (fig, ax) to draw into; a new one is created otherwise.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties. A dict is
        merged over this function's defaults; an object replaces them.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns
    -------
    tuple
        The (fig, ax) pair that was drawn into.

    Raises
    ------
    ValueError
        If both series are provided with different sizes.
    """

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{stat} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Normalise to arrays so that index selection works on plain lists too.
    if values_sim is not None:
        values_sim = np.asarray(values_sim)
    if values_obs is not None:
        values_obs = np.asarray(values_obs)
    names = np.asarray(names)

    # Fall back to outlet indices when no names are given (as plot_misfit does).
    reference = values_sim if values_sim is not None else values_obs
    if names.size == 0 and reference is not None:
        names = np.arange(len(reference))

    # An empty selection (the default) means "all outlets"; indexing with []
    # would otherwise drop every outlet.
    if columns is not None and len(columns) > 0:
        if values_sim is not None:
            values_sim = values_sim[columns]

        if values_obs is not None:
            values_obs = values_obs[columns]

        names = names[columns]

    # -99.0 everywhere means no observation is available.
    if values_obs is not None and np.all(values_obs == -99.0):
        values_obs = None

    if values_sim is not None and values_obs is not None:
        if values_obs.size != values_sim.size:
            raise ValueError("values_sim and values_obs must have the same size !")

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    series = []
    if values_sim is not None:
        series.append((values_sim, "sim"))
    if values_obs is not None:
        series.append((values_obs, "obs"))

    x = np.arange(len(names))
    width = 0.25  # the width of the bars

    for rank, (series_values, label) in enumerate(series):
        ax.bar(x + rank * width, series_values, width, label=label)

    # Bars are centre-aligned on x + rank * width: put each tick in the middle
    # of its group of bars.
    ax.set_xticks(x + width * max(len(series) - 1, 0) / 2, names)

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Plot a statistical criterion for a given list of outlets.
Parameters:
values_sim: np.ndarray or None The result of the simulated stat for all outlets. values_obs: np.ndarray or None The result of the observed stat for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray. columns: list | None Indices (columns) of the np.ndarray to plot stat: str Criterion to plot. Choices are ['max', 'min', 'mean', 'median', 'q20', 'q80'] figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties
def _misfit_color_bounds(misfit, values):
    """Return the (vmin, vmax) color-scale bounds used to display `misfit` values.

    NSE-like criteria are bounded to [0, 1], error criteria start at 0, and any
    other criterion spans the range of `values`. NaN values are ignored.
    """
    if misfit in ("nse", "nnse"):
        return 0, 1
    if misfit in ("rmse", "nrmse", "se"):
        return 0, np.nanmax(values)
    return np.nanmin(values), np.nanmax(values)


def plot_misfit_map(
    values: np.ndarray = [],
    names: np.ndarray = [],
    mesh=None,
    misfit: str = "nse",
    coef_hydro=99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
    figure: list | tuple | None = None,
):
    """
    Map plot of the misfit criterion between the simulated and observed discharges.

    The active domain is drawn in light gray. Cells whose accumulated flow is
    large enough are drawn with a gray scale, with their own colorbar. Each
    gauge of the mesh is drawn as a marker colored by its misfit value and
    annotated with its code and value.

    Parameters
    ----------
    values: np.ndarray
        The misfit value of each gauge, in the same order as mesh["code"].
    names: np.ndarray
        Currently unused: station codes are read from mesh["code"].
    mesh: dict
        The mesh of the Smash model as a dict. Mandatory.
    misfit: str
        Criterion to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
        It selects the color-scale bounds and labels the colorbar.
    coef_hydro: float
        Percentage controlling which cells of the flow-accumulation layer are
        shown. Cells below (1 - coef_hydro / 100) * max(accumulation) are hidden.
    catchment_polygon: DataFrame or None
        Polygons to overlay, e.g. catchment boundaries read with geopandas.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties. A dict is
        merged over this function's defaults; an object replaces them.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.
    plot_settings: dict or class plot_properties
        Object or dict with any attribute of class plot_properties, used for the
        gauge markers. Its `color` attribute is ignored, because the marker
        color encodes the misfit value.
    figure: tuple
        Input figure as (fig, ax) to draw into; a new one is created otherwise.

    Returns
    -------
    tuple
        The (fig, ax) pair that was drawn into.

    Raises
    ------
    ValueError
        If `mesh` is missing or is not a dict.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title=f"Map of {misfit} criteria over the domain.",
            xlabel="x_coords",
            ylabel="y_coords",
            cmap="turbo_r",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="o",
            markersize=8,
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    # unset attribute color, managed separately (it encodes the misfit value)
    delattr(default_plot_settings, "color")

    values = np.asarray(values, dtype=float)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    flow_acc = mesh["flwacc"]
    na = mesh["active_cell"] == 0

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Keep only the cells whose accumulated flow is above the threshold.
    flow_accum_bv = np.where(na, 0.0, flow_acc.data)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    mask_flow = flow_accum_bv < surfmin
    flow_plot = np.where(mask_flow, np.nan, flow_accum_bv.data)
    flow_plot = np.where(na, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    myblues = matplotlib.colormaps["binary"]
    flow_cmap = ListedColormap(myblues(np.linspace(0.20, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=flow_cmap, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    fig.colorbar(im, ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax)

    # Misfit values are mapped through an explicit normalisation. A colormap
    # only accepts inputs in [0, 1], so raw rmse/kge values cannot be passed
    # to it directly.
    vmin, vmax = _misfit_color_bounds(misfit, values)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap_name = default_ax_settings.cmap or plt.rcParams["image.cmap"]
    colormap = matplotlib.colormaps[cmap_name]

    for i, station in enumerate(stations):
        color = colormap(norm(values[i]))
        value_text = str(np.round(values[i], 2))

        # Alternate label alignment between consecutive stations (presumably
        # to limit overlapping annotations).
        if i % 2 == 0:
            ha = "left"
            str_val = value_text.rjust(len(station))
        else:
            ha = "right"
            str_val = value_text.ljust(len(station))
        code = f"{station}\n {str_val}"

        coord = geo_toolbox.rowcol_to_xy(
            gauge[i][0],
            gauge[i][1],
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        )

        ax.plot(
            coord[0],
            coord[1],
            color=color,
            **default_plot_settings.__dict__,
        )

        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=ha,
            color=color,
            fontsize=default_ax_settings.annotate_fontsize
            * default_ax_settings.font_ratio,
        )

    # Second colorbar, for the misfit values, to the right of the first one.
    cax = divider.append_axes("right", size="5%", pad=0.5)

    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=colormap),
        ax=ax,
        cax=cax,
        label=misfit,
        shrink=0.75,
        location="right",
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
Map plot of the misfit criteria between the simulated and observed discharges.
Parameters:
values: np.ndarray The result of the discharge misfit for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray. mesh: None | dict The mesh of the Smash model as dict misfit: str Criterion to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'] figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties.