smashbox.plot.plot

   1# -*- coding: utf-8 -*-
   2import matplotlib.pyplot as plt
   3import numpy as np
   4import matplotlib
   5from matplotlib.colors import ListedColormap
   6import pandas as pd
   7from pandas import DataFrame
   8from smashbox.stats import stats
   9from smashbox.tools import geo_toolbox
  10import datetime
  11from smash import Model
  12
  13import matplotlib.colors as mcolors
  14from matplotlib import cm
  15import colorsys
  16from mpl_toolkits.axes_grid1 import make_axes_locatable
  17
  18import os
  19import pandas as pd
  20from smashbox.init.param import param
  21
  22from smashbox.tools import tools
  23
  24
  25class plot_properties:
  26    """Class which handle differents properties of the matplotlib plot function.
  27    All attributes can be defined by the user and the object plot_properties can be
  28     passed to any smashbox plot function.
  29    """
  30
  31    def __init__(
  32        self,
  33        ls="-",
  34        lw=1.5,
  35        marker="",
  36        markersize=4,
  37        color="black",
  38        label="",
  39    ):
  40        self.ls = ls
  41        """The style of the line (see matplotlib documentation)"""
  42        self.lw = lw
  43        """The linewidth of the line (see matplotlib documentation)"""
  44        self.marker = marker
  45        """The style of the marker (see matplotlib documentation)"""
  46        self.markersize = markersize
  47        """The size of the marker (see matplotlib documentation)"""
  48        self.color = color
  49        """The color of the line (see matplotlib documentation)"""
  50        self.label = label
  51        """The label of the line (see matplotlib documentation)"""
  52
  53    def update(self, **kwargs):
  54        """Update the class attributes using kwarg (dictionnary)"""
  55        for key, values in kwargs.items():
  56            setattr(self, key, values)
  57
  58
  59class ax_properties:
  60    """Class which handle differents properties of the matplotlib ax object
  61    (see matplotlib documentation). All attributes can be defined by the user and the
  62     object ax_properties can be passed to any smashbox plot function.
  63    """
  64
  65    def __init__(
  66        self,
  67        title: str = None,
  68        xlabel: str = None,
  69        ylabel: str = None,
  70        clabel: str = None,
  71        font_ratio: int = 1,
  72        title_fontsize: int = 12,
  73        label_fontsize: int = 10,
  74        annotate_fontsize: int = 6,
  75        grid: bool = True,
  76        xscale: str | None = None,
  77        yscale: str | None = None,
  78        legend: bool = True,
  79        legend_loc: str = None,
  80        legend_fontsize: int = 8,
  81        xtics_fontsize: int = 8,
  82        ytics_fontsize: int = 8,
  83        cmap: str | None = None,
  84        xlim: tuple | list | None = (None, None),
  85        ylim: tuple | list | None = (None, None),
  86        xticklabels_rotation: int = 0,
  87        barlabel_fontsize=6,
  88    ):
  89
  90        self.title = title
  91        """The title of the graphic"""
  92        self.xlabel = xlabel
  93        """The label of the x axis"""
  94        self.ylabel = ylabel
  95        """The label of the y axis"""
  96        self.clabel = clabel
  97        """The label of the colorbar"""
  98        self.font_ratio = font_ratio
  99        """Ratio of the global fontsize"""
 100        self.title_fontsize = title_fontsize
 101        """The fontsize of the title"""
 102        self.label_fontsize = label_fontsize
 103        """The label fontsize"""
 104        self.annotate_fontsize = annotate_fontsize
 105        """The annotation fontsize in plot"""
 106        self.grid = grid
 107        """Set the grid (boolean), default True"""
 108        self.xscale = xscale
 109        """Scale of the x axis (see matplotlib documentation)"""
 110        self.yscale = yscale
 111        """Scale of the x axis (see matplotlib documentation)"""
 112        self.legend = legend
 113        """Set the legend, boolean, default is True"""
 114        self.legend_loc = legend_loc
 115        """Localisation of the legend"""
 116        self.legend_fontsize = legend_fontsize
 117        """The fontsize of the legend"""
 118        self.xtics_fontsize = xtics_fontsize
 119        """The fontsize of the xtics"""
 120        self.ytics_fontsize = ytics_fontsize
 121        """The fontsize of the ytics"""
 122        self.cmap = cmap
 123        """The name of the used colormap"""
 124        self.xlim = xlim
 125        """The limit of the x axis, tuple or list, default is (None,None)"""
 126        self.ylim = ylim
 127        """The limit of the y axis, tuple or list, default is (None,None)"""
 128        self.xticklabels_rotation = xticklabels_rotation
 129        """Angle of the xtics labels, float, default is 0."""
 130        self.barlabel_fontsize = barlabel_fontsize * self.font_ratio
 131        "Fontsize of the bar top label"
 132
 133    def update(self, **kwargs):
 134        """Update the class attributes using kwarg (dictionnary)"""
 135        for key, values in kwargs.items():
 136            setattr(self, key, values)
 137
 138    def change(self, figure):
 139        """Apply change to the current figure `figure`
 140        Parameter:
 141        ----------
 142        figure : list or tuple
 143            list of (fig, ax), figure and ax of a matplotlib subplot.
 144        Return:
 145        -------
 146            a tuple of the modified (fig,ax)
 147        """
 148        fig, ax = figure
 149
 150        plt.rcParams.update(plt.rcParamsDefault)
 151
 152        plt.rcParams.update(
 153            {
 154                "axes.labelsize": self.label_fontsize * self.font_ratio,
 155                "axes.titlesize": self.title_fontsize * self.font_ratio,
 156                "legend.fontsize": self.legend_fontsize * self.font_ratio,
 157                "figure.titlesize": self.title_fontsize * self.font_ratio,
 158                "xtick.labelsize": self.xtics_fontsize * self.font_ratio,
 159                "ytick.labelsize": self.ytics_fontsize * self.font_ratio,
 160            }
 161        )
 162
 163        if self.title is not None:
 164
 165            ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio)
 166
 167        if self.xlabel is not None:
 168            ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio)
 169
 170        if self.ylabel is not None:
 171            ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio)
 172
 173        if self.grid:
 174            ax.grid(True, which="both", linestyle="--", alpha=0.5)
 175
 176        if self.xscale is not None:
 177            ax.set_xscale(self.xscale)
 178
 179        if self.yscale is not None:
 180            ax.set_xyscale(self.yscale)
 181
 182        if self.legend:
 183            ax.legend(
 184                loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio
 185            )
 186
 187        if self.cmap is not None:
 188            plt.rc("image", cmap=self.cmap)
 189
 190        if self.ylim[0] is not None:
 191            ax.set_ylim(bottom=self.ylim[0])
 192
 193        if self.ylim[1] is not None:
 194            ax.set_ylim(top=self.ylim[1])
 195
 196        if self.xlim[0] is not None:
 197            ax.set_xlim(left=self.xlim[0])
 198
 199        if self.xlim[1] is not None:
 200            ax.set_xlim(right=self.xlim[1])
 201
 202        if self.xticklabels_rotation > 0:
 203            ax.set_xticklabels(
 204                ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right"
 205            )
 206
 207        return fig, ax
 208
 209
 210class fig_properties:
 211    """Class which handle differents properties of the matplotlib fig object
 212    (see matplotlib documentation). All attributes can be defined by the user and the
 213     object ax_properties can be passed to any smashbox plot function.
 214    """
 215
 216    def __init__(
 217        self,
 218        figname=None,
 219        xsize=8,
 220        ysize=6,
 221        transparent=False,
 222        dpi=160,
 223        font_ratio=1,
 224        bbox_inches="tight",
 225    ):
 226
 227        self.figname = figname
 228        """Path to the figure name to be saved"""
 229        self.xsize = xsize
 230        """Width of the figure in inch"""
 231        self.ysize = ysize
 232        """Height of the figure in inch"""
 233        self.transparent = transparent
 234        """Use transparency when exporting the figure, default is False"""
 235        self.dpi = dpi
 236        """RƩsolution (dpi), int, default is 80"""
 237        self.font_ratio = font_ratio
 238        """Global font ratio"""
 239        self.bbox_inches = bbox_inches
 240        """Constraint of the boundingbox of each ax (see matplotlib docuentation)"""
 241
 242    def update(self, **kwargs):
 243        """Update the class attributes using kwarg (dictionnary)"""
 244
 245        for key, values in kwargs.items():
 246            setattr(self, key, values)
 247
 248    def change(self, figure):
 249        """Apply change to the current figure `figure`
 250        Parameter:
 251        ----------
 252        figure : list or tuple
 253            list of (fig, ax), figure and ax of a matplotlib subplot.
 254        Return:
 255        -------
 256            a tuple of the modified (fig,ax)
 257        """
 258
 259        fig, ax = figure
 260
 261        fig.set_figheight(self.ysize)
 262        fig.set_figwidth(self.xsize)
 263
 264        plt.rc(
 265            "font", size=plt.rcParams["font.size"] * self.font_ratio
 266        )  # controls default text sizes
 267        plt.rc(
 268            "axes", titlesize=plt.rcParams["axes.titlesize"] * self.font_ratio
 269        )  # fontsize of the axes title
 270        plt.rc(
 271            "axes", labelsize=plt.rcParams["axes.labelsize"] * self.font_ratio
 272        )  # fontsize of the x and y labels
 273        plt.rc(
 274            "xtick", labelsize=plt.rcParams["xtick.labelsize"] * self.font_ratio
 275        )  # fontsize of the tick labels
 276        plt.rc(
 277            "ytick", labelsize=plt.rcParams["ytick.labelsize"] * self.font_ratio
 278        )  # fontsize of the tick labels
 279        plt.rc(
 280            "legend", fontsize=plt.rcParams["legend.fontsize"] * self.font_ratio
 281        )  # legend fontsize
 282        plt.rc(
 283            "figure", titlesize=plt.rcParams["figure.titlesize"] * self.font_ratio
 284        )  # fontsize of the figure title
 285
 286        if self.figname is not None:
 287
 288            head_path, basename = os.path.split(self.figname)
 289
 290            if len(head_path) > 0 and not os.path.exists(head_path):
 291                os.makedirs(head_path)
 292
 293            fig.savefig(
 294                self.figname,
 295                transparent=self.transparent,
 296                dpi=self.dpi,
 297                bbox_inches=self.bbox_inches,
 298            )
 299
 300        return fig, ax
 301
 302
 303@tools.autocast_args
 304def save_figure(
 305    fig=None, figname="myfigure", xsize=8, ysize=6, transparent=False, dpi=80
 306):
 307    """
 308    Save a figure.
 309    Parameters:
 310    -----------
 311    fig: fig object returned by matplotlib.subplot
 312        the figure to save
 313    figname: str
 314        Path to the figure
 315    xsize: int
 316        width of the figure in inch
 317    ysize: int
 318        height of the figure in inch
 319    transparent: bool, default is False
 320        use transparency
 321    dpi : int
 322        resolution of the figure, default is 80
 323    """
 324    fig.set_size_inches(xsize, ysize, forward=True)
 325    fig.savefig(figname, transparent=transparent, dpi=dpi, bbox_inches="tight")
 326
 327
 328def generate_palette(base_color, n, variation="hue"):
 329    """
 330    Generate a palette of colors from a base color.
 331    Parameter:
 332    ---------
 333    base_color: str
 334        matplotlib color string
 335    n: int
 336        number of color to generate
 337    variation: 'hue' | 'brightness'
 338        how to generate the color palette, by changing the hue or the brighness of the base color.
 339    Return: a list of colors
 340    """
 341    # Convertir la couleur de base en format RGB normalisƩ (0-1)
 342    rgb = mcolors.to_rgb(base_color)
 343
 344    # Convertir en HSV
 345    h, s, v = colorsys.rgb_to_hsv(*rgb)
 346
 347    # GƩnƩrer n couleurs en modifiant la teinte ou valeur
 348    palette = []
 349    for i in range(n):
 350        new_h = h
 351        if variation == "hue":
 352            new_h = (h + i / n) % 1.0  # cycle dans le cercle chromatique
 353        new_v = v
 354        if variation == "brightness":
 355            new_v = max(0.1, min(1.0, v * (0.5 + i / (2 * n))))  # Ʃviter le noir complet
 356
 357        new_rgb = colorsys.hsv_to_rgb(new_h, s, new_v)
 358        palette.append(new_rgb)
 359
 360    return palette
 361
 362
 363@tools.autocast_args
 364def plot_chro(
 365    data: np.ndarray = np.zeros(shape=(1, 10)),
 366    t_axis: int = 1,
 367    outlets_name: list | tuple = [],
 368    columns: list | tuple = [],
 369    dt: float = 0.0,
 370    xtics: list | tuple = [],
 371    date_range: list = None,
 372    figure=None,
 373    ax_settings: dict | ax_properties = ax_properties(),
 374    fig_settings: dict | fig_properties = fig_properties(),
 375    plot_settings: dict | plot_properties = plot_properties(),
 376):
 377    """
 378    Plot a temporal chonic of values
 379    Parameters:
 380    -----------
 381    data: np.ndarray of dimenion 2.
 382        data to plot as a matrix of 2 dimension.
 383    t_axis : int
 384        the axis of the time in data, default is 1
 385    outlets_name: list
 386        the list of the outlets name
 387    columns : list
 388        the column to be plotted in t_axis direction
 389    dt: float
 390        the timestep
 391    xtics : list
 392        list of date for the xtics. The format must be automatically read by numpy.Datetime
 393    date_range: list
 394        list of [date_start, date_end, timedelta] to generate the xtics
 395    figure: tuple
 396        input figure as (fig,ax) to add a new curve
 397    ax_settings: dict or class ax_properties
 398        object or dict with any attribute of class ax_properties
 399    fig_settings: dict or class ax_properties
 400        object or dict with any attribute of class fig_settings
 401    plot_settings: dict or class plot_properties
 402        object or dict with any attribute of class plot_settings
 403
 404    """
 405
 406    if isinstance(ax_settings, dict):
 407        ax_settings = ax_properties(**ax_settings)
 408    else:
 409        ax_settings = ax_properties(**ax_settings.__dict__)
 410
 411    if isinstance(fig_settings, dict):
 412        fig_settings = fig_properties(**fig_settings)
 413    else:
 414        fig_settings = fig_properties(**fig_settings.__dict__)
 415
 416    if isinstance(plot_settings, dict):
 417        plot_settings = plot_properties(**plot_settings)
 418    else:
 419        plot_settings = plot_properties(**plot_settings.__dict__)
 420
 421    data = np.moveaxis(data, t_axis, 0)
 422
 423    if figure is not None:
 424        fig, ax = figure
 425    else:
 426        fig, ax = plt.subplots()
 427
 428    fig, ax = ax_settings.change(figure=(fig, ax))
 429
 430    if len(xtics) == 0:
 431        xtics = np.arange(0, data.shape[0])
 432        if dt > 0:
 433            xtics = xtics * dt
 434    else:
 435        for i in range(len(xtics)):
 436            xtics[i] = np.datetime64(xtics[i])
 437
 438    if date_range is not None:
 439        if len(date_range) != 3:
 440            raise ValueError(
 441                "date_range must have a length of 3: [date_start, date_end, step (s)]"
 442            )
 443        xtics = np.arange(
 444            np.datetime64(date_range[0]),
 445            np.datetime64(date_range[1] + pd.Timedelta(seconds=int(date_range[2]))),
 446            np.timedelta64(int(date_range[2]), "s"),
 447        )
 448
 449    if len(columns) > 0:
 450
 451        args = plot_settings.__dict__.copy()
 452        print(args)
 453        del args["label"]
 454        del args["color"]
 455
 456        palette = generate_palette(plot_settings.color, len(columns))
 457
 458        for i in columns:
 459            ax.plot(xtics[:], data[:, i], **args, label=outlets_name[i], color=palette[i])
 460    else:
 461        ax.plot(xtics[:], data[:, 0], **plot_settings.__dict__)
 462
 463    fig, ax = ax_settings.change(figure=(fig, ax))
 464    fig, ax = fig_settings.change(figure=(fig, ax))
 465
 466    return fig, ax
 467
 468
 469def plot_hydrograph(
 470    model: Model | None = None,
 471    columns: list | tuple = [],
 472    outlets_name: list | tuple = [],
 473    plot_rainfall: bool = True,
 474    plot_qobs: bool = True,
 475    figure: list | tuple | None = None,
 476    ax_settings: dict | ax_properties = {},
 477    fig_settings: dict | fig_properties = {},
 478    plot_settings_sim: dict | plot_properties = {},
 479    plot_settings_obs: dict | plot_properties = {},
 480):
 481    """
 482    Plot an hydrograph from a smash model
 483    Parameters:
 484    -----------
 485    model: a smash model object
 486        a smash model object
 487    outlets_name: list
 488        the list of the outlets name
 489    columns : list
 490        the column to be plotted in t_axis direction
 491    figure: tuple
 492        input figure as (fig,ax) to add a new curve
 493    ax_settings: dict or class ax_properties
 494        object or dict with any attribute of class ax_properties
 495    fig_settings: dict or class ax_properties
 496        object or dict with any attribute of class fig_settings
 497    plot_settings_sim: dict or class ax_properties
 498        object or dict with any attribute of class plot_settings.Control the simulated curve
 499    plot_settings_obs: dict or class ax_properties
 500        object or dict with any attribute of class plot_settings. Control the observed curve
 501
 502    """
 503    if model is None:
 504        raise ValueError("Input smash model object is None.")
 505
 506    if isinstance(ax_settings, dict):
 507        default_ax_settings = ax_properties(
 508            xlabel="Time", ylabel="discharges m^3/s", xtics_fontsize=10, ytics_fontsize=10
 509        )
 510        default_ax_settings.update(**ax_settings)
 511    else:
 512        default_ax_settings = ax_properties(**ax_settings.__dict__)
 513
 514    if isinstance(fig_settings, dict):
 515        fig_settings = fig_properties(**fig_settings)
 516    else:
 517        fig_settings = fig_properties(**fig_settings.__dict__)
 518
 519    if isinstance(plot_settings_sim, dict):
 520        default_plot_settings_sim = plot_properties(
 521            ls="-",
 522            lw="2",
 523            marker="",
 524            markersize=4,
 525            color="blue",
 526            label="Sim",
 527        )
 528        default_plot_settings_sim.update(**plot_settings_sim)
 529    else:
 530        default_plot_settings_sim = plot_properties(**plot_settings_sim.__dict__)
 531
 532    # default color for multi curves: same color but different line type
 533    if len(columns) >= 2:
 534        color = "blue"
 535    else:
 536        color = "black"
 537
 538    if isinstance(plot_settings_obs, dict):
 539        default_plot_settings_obs = plot_properties(
 540            ls="--",
 541            lw="1.5",
 542            marker="",
 543            markersize=4,
 544            color=color,
 545            label="Obs",
 546        )
 547        default_plot_settings_obs.update(**plot_settings_obs)
 548    else:
 549        default_plot_settings_obs = plot_properties()
 550
 551    # manage date here
 552    date_deb = datetime.datetime.fromisoformat(
 553        model.setup.start_time
 554    ) + datetime.timedelta(seconds=int(model.setup.dt))
 555    date_end = datetime.datetime.fromisoformat(model.setup.end_time)
 556    date_range = [date_deb, date_end, model.setup.dt]
 557
 558    if figure is None:
 559        if plot_rainfall:
 560            fig, (ax2, ax1) = plt.subplots(2, 1, height_ratios=[1, 4])
 561            fig.subplots_adjust(hspace=0)
 562            figure = [fig, ax1, ax2]
 563        else:
 564            fig, ax2 = plt.subplots()
 565            figure = [fig, ax1]
 566    else:
 567        if plot_rainfall:
 568            fig = figure[0]
 569            ax1 = figure[1]
 570            ax2 = figure[2]
 571        else:
 572            fig = figure[0]
 573            ax1 = figure[1]
 574
 575    fig, ax = default_ax_settings.change(figure=(fig, ax1))
 576
 577    if plot_qobs:
 578        fig, ax1 = plot_chro(
 579            np.where(model.response_data.q<0, np.nan, model.response_data.q),
 580            date_range=date_range,
 581            columns=columns,
 582            outlets_name=["obs_" + name for name in outlets_name],
 583            figure=(fig, ax1),
 584            ax_settings=default_ax_settings,
 585            fig_settings=fig_settings,
 586            plot_settings=default_plot_settings_obs,
 587        )
 588
 589    fig, ax1 = plot_chro(
 590        model.response.q,
 591        date_range=date_range,
 592        columns=columns,
 593        outlets_name=["sim_" + name for name in outlets_name],
 594        figure=(fig, ax1),
 595        ax_settings=default_ax_settings,
 596        fig_settings=fig_settings,
 597        plot_settings=default_plot_settings_sim,
 598    )
 599
 600    xtics = np.arange(
 601        np.datetime64(date_range[0]),
 602        np.datetime64(date_range[1] + datetime.timedelta(seconds=int(date_range[2]))),
 603        np.timedelta64(int(date_range[2]), "s"),
 604    )
 605
 606    axes_list = [ax1]
 607
 608    if plot_rainfall:
 609
 610        if len(columns) > 0:
 611            col = columns[0]
 612        else:
 613            col = 0
 614
 615        ax2.bar(
 616            xtics[:],
 617            model.atmos_data.mean_prcp[col, :],
 618            label="Average rainfall (mm)",
 619            width=np.timedelta64(int(date_range[2]), "s"),
 620            color="blue",
 621        )
 622
 623        ax2.invert_yaxis()
 624        ax2.grid(alpha=0.7, ls="--")
 625        ax2.get_xaxis().set_visible(False)
 626        ax2.set_ylim(bottom=1.2 * max(model.atmos_data.mean_prcp[0, :]), top=0.0)
 627        ax2.set_ylabel("Average rainfall (mm)")
 628
 629        axes_list.append(ax2)
 630
 631    fig, ax = fig_settings.change(figure=(fig, tuple(axes_list)))
 632
 633    return fig, ax
 634
 635
 636def plot_catchment_surface_error(
 637    mesh: dict = None,
 638    ax_settings: dict | ax_properties = {},
 639    fig_settings: dict | fig_properties = {},
 640):
 641    """
 642    Plot the misfit criteria between the simulated and observed discharges.
 643    Parameters:
 644    -----------
 645    values: np.ndarray
 646        The result of the discharge misfit for all outlets.
 647    names: np.ndarray
 648        Outlets name or code stored in an np.ndarray.
 649    columns: list | None
 650        Columns of the np.ndarray to plot
 651    misfit: str
 652        Criteria to plot. choice are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge']"
 653    figure: tuple
 654        input figure as (fig,ax) to add a new curve
 655    ax_settings: dict or class ax_properties
 656        object or dict with any attribute of class ax_properties
 657    fig_settings: dict or class ax_properties
 658        object or dict with any attribute of class fig_settings
 659    """
 660
 661    if isinstance(ax_settings, dict):
 662        default_ax_settings = ax_properties(
 663            title="Catchment surface error (Ssim-Sobs)/Sobs *100",
 664            ylabel="Surface error %",
 665            xlabel="Outlets",
 666            xticklabels_rotation=45,
 667            xtics_fontsize=6,
 668        )
 669        default_ax_settings.update(**ax_settings)
 670    else:
 671        default_ax_settings = ax_properties(**ax_settings.__dict__)
 672
 673    if isinstance(fig_settings, dict):
 674        fig_settings = fig_properties(**fig_settings)
 675    else:
 676        fig_settings = fig_properties(**fig_settings.__dict__)
 677
 678    if len(mesh["code"]) == 0:
 679        print("Cannot plot this, the mesh has no gauge !")
 680        return None, None
 681
 682    plt.rcParams.update(plt.rcParamsDefault)
 683    fig, ax = plt.subplots()
 684    fig, ax = default_ax_settings.change(figure=(fig, ax))
 685
 686    surface_error = (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100
 687
 688    fig, ax = default_ax_settings.change(figure=(fig, ax))
 689    bar_container = ax.bar(
 690        mesh["code"], surface_error, color="grey", tick_label=mesh["code"]
 691    )
 692
 693    ax.bar_label(
 694        bar_container,
 695        fmt=lambda x: f"{x:.2f}",
 696        fontsize=default_ax_settings.barlabel_fontsize,
 697    )
 698
 699    fig, ax = default_ax_settings.change(figure=(fig, ax))
 700    fig, ax = fig_settings.change(figure=(fig, ax))
 701
 702    return fig, ax
 703
 704
 705def plot_catchment_surface_consistency(
 706    mesh: dict = None,
 707    label: bool = True,
 708    ax_settings: dict | ax_properties = {},
 709    fig_settings: dict | fig_properties = {},
 710    plot_settings: dict | plot_properties = {},
 711):
 712    """
 713    Plot the modeled surface vs the observed surface
 714    Parameters:
 715    -----------
 716    mesh: dict, optional
 717        The mesh of the Smash model, defaults to None
 718    ax_settings: dict or class ax_properties
 719        object or dict with any attribute of class ax_properties
 720    fig_settings: dict or class ax_properties
 721        object or dict with any attribute of class fig_settings
 722    plot_settings: dict or class plot_properties
 723        object or dict with any attribute of class plot_settings.Control the simulated curve
 724    """
 725
 726    if isinstance(ax_settings, dict):
 727        default_ax_settings = ax_properties(
 728            title="Modeled and observed surface consistency",
 729            ylabel="Modeled surface",
 730            xlabel="Observed surface",
 731        )
 732        default_ax_settings.update(**ax_settings)
 733    else:
 734        default_ax_settings = ax_properties(**ax_settings.__dict__)
 735
 736    if isinstance(fig_settings, dict):
 737        fig_settings = fig_properties(**fig_settings)
 738    else:
 739        fig_settings = fig_properties(**fig_settings.__dict__)
 740
 741    if isinstance(plot_settings, dict):
 742        default_plot_settings = plot_properties(
 743            marker="+",
 744            markersize=12,
 745            color="blue",
 746        )
 747        default_plot_settings.update(**plot_settings)
 748    else:
 749        default_plot_settings = plot_properties(**plot_settings.__dict__)
 750
 751    if len(mesh["code"]) == 0:
 752        print("Cannot plot this, the mesh has no gauge !")
 753        return None, None
 754
 755    surface_model = mesh["area_dln"] / 1000.0**2.0
 756    surface_obs = mesh["area"] / 1000.0**2.0
 757
 758    plt.rcParams.update(plt.rcParamsDefault)
 759    fig, ax = plt.subplots()
 760    fig, ax = default_ax_settings.change(figure=(fig, ax))
 761
 762    ax.plot(
 763        surface_obs,
 764        surface_model,
 765        markersize=default_plot_settings.markersize,
 766        marker=default_plot_settings.marker,
 767        color=default_plot_settings.color,
 768        linestyle="None",
 769    )
 770    ax.plot(
 771        np.linspace(min(surface_obs), max(surface_obs), 10),
 772        np.linspace(min(surface_obs), max(surface_obs), 10),
 773        linewidth=2,
 774        color="grey",
 775    )
 776
 777    if label:
 778        ha = ("left", "right")
 779        for i, label in enumerate(mesh["code"]):
 780            ax.annotate(
 781                label,  # this is the text
 782                (
 783                    surface_obs[i],
 784                    surface_model[i],
 785                ),  # these are the coordinates to position the label
 786                textcoords="data",  # how to position the text
 787                xytext=(
 788                    surface_obs[i],
 789                    surface_model[i],
 790                ),  # distance from text to points (x,y)
 791                ha=ha[i % 2],  # horizontal alignment can be left, right or center
 792                color="red",
 793                fontsize=default_ax_settings.annotate_fontsize,
 794            )
 795
 796    ax.set(xlabel=default_ax_settings.xlabel, ylabel=default_ax_settings.ylabel)
 797
 798    fig, ax = fig_settings.change(figure=(fig, ax))
 799
 800    return fig, ax
 801
 802
 803def plot_mesh(
 804    mesh: dict = None,
 805    coef_hydro: float = 99.0,
 806    catchment_polygon: None | DataFrame = None,
 807    ax_settings: dict | ax_properties = {},
 808    fig_settings: dict | fig_properties = {},
 809):
 810    """
 811    Plot the mesh of a smash model
 812    Parameters:
 813    -----------
 814    mesh: a smash mesh as dictionary
 815        a smash model object
 816    coef_hydro: float
 817        the coefficient to colorize the hydrographic network accodring the cumulative
 818        surface. default is 99% so that 99% of the cell will be hidden.
 819    ax_settings: dict or class ax_properties
 820        object or dict with any attribute of class ax_properties
 821    fig_settings: dict or class ax_properties
 822        object or dict with any attribute of class fig_settings
 823    """
 824
 825    if mesh is not None:
 826        if isinstance(mesh, dict):
 827            pass
 828        else:
 829            raise ValueError("mesh must be a dict")
 830    else:
 831        raise ValueError(
 832            "model or mesh are mandatory and must be a dict or a smash Model object"
 833        )
 834
 835    if isinstance(ax_settings, dict):
 836        default_ax_settings = ax_properties(
 837            title="Mesh of the Smash model",
 838            xlabel="x_coords",
 839            ylabel="y_coords",
 840        )
 841        default_ax_settings.update(**ax_settings)
 842    else:
 843        default_ax_settings = ax_properties(**ax_settings.__dict__)
 844
 845    if isinstance(fig_settings, dict):
 846        fig_settings = fig_properties(**fig_settings)
 847    else:
 848        fig_settings = fig_properties(**fig_settings.__dict__)
 849
 850    # mesh["active_cell"]
 851    gauge = mesh["gauge_pos"]
 852    stations = mesh["code"]
 853    flow_acc = mesh["flwacc"]
 854    na = mesh["active_cell"] == 0
 855
 856    flow_accum_bv = np.where(na, 0.0, flow_acc.data / 1000000.0)
 857    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
 858    mask_flow = flow_accum_bv < surfmin
 859    flow_plot = np.where(mask_flow, np.nan, flow_accum_bv)
 860    flow_plot = np.where(na, np.nan, flow_plot)
 861
 862    plt.rcParams.update(plt.rcParamsDefault)
 863    fig, ax = plt.subplots()
 864    fig, ax = default_ax_settings.change(figure=(fig, ax))
 865
 866    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
 867    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])
 868
 869    active_cell = np.where(na, np.nan, mesh["active_cell"])
 870    cmap = ListedColormap(["lightgray"])
 871    ax.imshow(active_cell, cmap=cmap, extent=extent)
 872
 873    myblues = matplotlib.colormaps["Blues"]
 874    cmp = ListedColormap(myblues(np.linspace(0.30, 1.0, 265)))
 875    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)
 876
 877    if catchment_polygon is not None:
 878        # catchment_polygon = gpd.read_file(outlets_shapefile)
 879        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")
 880
 881    # create an axes on the right side of ax. The width of cax will be 5%
 882    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
 883    divider = make_axes_locatable(ax)
 884    cax = divider.append_axes("right", size="5%", pad=0.05)
 885
 886    fig.colorbar(
 887        im, cmap="Blues", ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax
 888    )
 889
 890    pos_y = -5
 891    ha = "right"
 892    for i in range(len(stations)):
 893        if pos_y > 0:
 894            pos_y = -10
 895        else:
 896            pos_y = 5
 897        # pos_y=-1*pos_y
 898
 899        if ha == "right":
 900            ha = "left"
 901            pos_x = 5
 902        else:
 903            ha = "right"
 904            # pos_x = -5
 905
 906        coord = geo_toolbox.rowcol_to_xy(
 907            gauge[i][0],
 908            gauge[i][1],
 909            mesh["xmin"],
 910            mesh["ymax"],
 911            mesh["xres"],
 912            mesh["yres"],
 913        ) + np.array(
 914            [
 915                mesh["dx"][gauge[i][0], gauge[i][1]] / 2,
 916                -mesh["dx"][gauge[i][0], gauge[i][1]] / 2,
 917            ]
 918        )
 919
 920        code = stations[i]
 921        ax.plot(coord[0], coord[1], color="green", marker="o", markersize=6)
 922        ax.annotate(
 923            code,  # this is the text
 924            # these are the coordinates to position the label
 925            (coord[0], coord[1]),
 926            # textcoords="offset points",  # how to position the text
 927            # xytext=(pos_x, pos_y),  # distance from text to points (x,y)
 928            textcoords="data",  # how to position the text
 929            xytext=(coord[0], coord[1]),  # distance from text to points (x,y)
 930            ha=ha,  # horizontal alignment can be left, right or center
 931            color="red",
 932            fontsize=10,
 933        )
 934
 935    fig, ax = default_ax_settings.change(figure=(fig, ax))
 936
 937    fig, ax = fig_settings.change(figure=(fig, ax))
 938
 939    return fig, ax
 940
 941
 942def plot_xy_quantile(
 943    res_quantile,
 944    X,
 945    Y,
 946    res_quantile_obs=None,
 947    gauge_pos=None,
 948    figure=None,
 949    ax_settings: dict | ax_properties = {},
 950    fig_settings: dict | fig_properties = {},
 951    plot_settings: dict | plot_properties = {},
 952):
 953    """
 954    Plot the discharges quantiles fitting at X,Y coordinates.
 955    Parameters:
 956    -----------
 957    res_quantile: dict
 958        The result of the discharge quantile computation.
 959    res_quantile_obs: dict
 960        The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs()
 961    gauge_pos: int
 962        gauge_pos is the index of gauge in the Smash mesh for which the quantile_discharge are provided to the function.
 963    X: int
 964        Coordinates of the pixel in the row directions (X means row)
 965    Y: int
 966        Coordinates of the pixel in the column directions (Y means column)
 967    figure: tuple
 968        input figure as (fig,ax) to add a new curve
 969    ax_settings: dict or class ax_properties
 970        object or dict with any attribute of class ax_properties
 971    fig_settings: dict or class ax_properties
 972        object or dict with any attribute of class fig_settings
 973    """
 974    if isinstance(ax_settings, dict):
 975        default_ax_settings = ax_properties(
 976            xscale="log",
 977            xlabel=f"Return period (*{res_quantile['chunk_size']} days)",
 978            ylabel="Discharges (m³/s)",
 979            grid=True,
 980            legend=True,
 981        )
 982        default_ax_settings.update(**ax_settings)
 983    else:
 984        default_ax_settings = ax_properties(**ax_settings.__dict__)
 985
 986    if isinstance(fig_settings, dict):
 987        fig_settings = fig_properties(**fig_settings)
 988    else:
 989        fig_settings = fig_properties(**fig_settings.__dict__)
 990
 991    if isinstance(plot_settings, dict):
 992        default_plot_settings = plot_properties(markersize=10)
 993        default_plot_settings.update(**plot_settings)
 994    else:
 995        default_plot_settings = plot_properties(**plot_settings.__dict__)
 996
 997    quantile = res_quantile["Q_th"][X, Y]
 998    maxima = res_quantile["maxima"][X, Y]
 999    T_emp = res_quantile["T_emp"]
1000    loc = res_quantile["fit_loc"][X, Y]
1001    scale = res_quantile["fit_scale"][X, Y]
1002    shape = res_quantile["fit_shape"][X, Y]
1003    fit = res_quantile["fit"]
1004
1005    sorted_data = np.sort(maxima)
1006
1007    plt.rcParams.update(plt.rcParamsDefault)
1008    if figure is None:
1009        fig, ax = plt.subplots()
1010    else:
1011        fig, ax = figure
1012
1013    fig, ax = default_ax_settings.change(figure=(fig, ax))
1014
1015    if res_quantile_obs is not None and len(res_quantile_obs.keys()) > 0:
1016        if gauge_pos is None:
1017            raise ValueError(
1018                "gauge_pos is None. gauge_pos argument must be an integer corresponding to the gauge index."
1019            )
1020        maxima_obs = res_quantile_obs["maxima"][gauge_pos, :]
1021        T_emp_obs = res_quantile_obs["Temp"][gauge_pos, :]
1022
1023        ax.plot(
1024            T_emp_obs,
1025            maxima_obs,
1026            "o",
1027            label="Observed",
1028            color="black",
1029            markersize=default_plot_settings.markersize,
1030        )
1031
1032    ax.plot(
1033        T_emp,
1034        sorted_data,
1035        "o",
1036        label="Empirical",
1037        markersize=default_plot_settings.markersize,
1038    )
1039
1040    ax.plot(
1041        res_quantile["T"],
1042        quantile,
1043        "x",
1044        label="Theorical",
1045        markersize=default_plot_settings.markersize,
1046    )
1047
1048    Trange = np.linspace(1.1, np.max(res_quantile["T"]), 50)
1049
1050    if fit == "gumbel":
1051        ax.plot(
1052            Trange,
1053            [stats.quantile_gumbel(T, loc, scale) for T in Trange],
1054            "r--",
1055            label=f"{fit} fitted",
1056            lw=default_plot_settings.lw,
1057        )
1058
1059    if fit == "gev":
1060        ax.plot(
1061            Trange,
1062            [stats.quantile_gev(T, shape, loc, scale) for T in Trange],
1063            "r--",
1064            label=f"{fit} fitted",
1065            lw=default_plot_settings.lw,
1066        )
1067
1068    if "Umax" in res_quantile.keys() and "Umin" in res_quantile.keys():
1069        if res_quantile["Umax"] is not None and res_quantile["Umin"] is not None:
1070            ax.plot(
1071                res_quantile["T"],
1072                res_quantile["Umax"][X, Y],
1073                "r--",
1074                label="Uncertainties (max)",
1075                color="grey",
1076                lw=default_plot_settings.lw,
1077            )
1078            ax.plot(
1079                res_quantile["T"],
1080                res_quantile["Umin"][X, Y],
1081                "r--",
1082                label="Uncertainties (min)",
1083                color="grey",
1084                lw=default_plot_settings.lw,
1085            )
1086
1087    fig, ax = default_ax_settings.change(figure=(fig, ax))
1088    fig, ax = fig_settings.change(figure=(fig, ax))
1089
1090    return fig, ax
1091
1092
def plot_image(
    matrice=np.zeros(shape=(2, 2)),
    bbox=None,
    vmin=None,
    vmax=None,
    mask=None,
    extend=None,
    catchment_polygon=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Function for plotting a matrix as an image

    Parameters
    ----------
    matrice : numpy array
        Matrix to be plotted. It is copied as float32, so the caller's array is
        never modified.
    bbox : dict
        Bounding box with keys "left", "right", "bottom", "top". When given, the
        axes show x and y coordinates instead of the matrix indices.
    vmin: real,
        Minimum z value. Defaults to the minimum of the finite (non-masked) values.
    vmax: real,
        Maximum z value. Defaults to the maximum of the finite (non-masked) values.
    mask: integer matrix, same shape as matrice
        Contains 0 for pixels that should not be plotted.
    extend: unused
        Kept for backward compatibility.
    catchment_polygon: dataframe containing some polygon to be plotted.
        Ideally it must contain the boundaries of the catchment as a polygon from
        a shp file read by geopandas.
    figure: tuple
        Input figure as (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties (e.g. title,
        xlabel, ylabel, clabel, cmap).
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns
    ----------
    (fig, ax): the matplotlib figure and axes.

    Examples
    ----------
    plot_image(
        mesh_france['drained_area'],
        bbox=bbox,
        vmin=0.0,
        vmax=1000,
        mask=mesh_france['global_active_cell'],
        ax_settings={
            "title": "Drained areas",
            "xlabel": "Longitude",
            "ylabel": "Latitude",
            "clabel": "Drained areas km^2",
        },
    )

    """

    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Always work on a float32 *copy*. np.float32(array) returns the input array
    # itself when it is already float32, and the masking below would then write
    # NaN into the caller's data.
    matrice = np.array(matrice, dtype=np.float32)

    extent = None
    if bbox is not None:
        extent = [
            bbox["left"],
            bbox["right"],
            bbox["bottom"],
            bbox["top"],
        ]

    if mask is not None:
        matrice[np.where(mask == 0)] = np.nan

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    # Default color limits use finite values only. Masked pixels are NaN, and
    # np.max/np.min would return NaN limits (blank image, invalid colorbar). If
    # nothing is finite, the limits stay None and matplotlib autoscales.
    finite_values = matrice[np.isfinite(matrice)]
    if finite_values.size > 0:
        if vmax is None:
            vmax = np.max(finite_values)
        if vmin is None:
            vmin = np.min(finite_values)

    fig, ax = ax_settings.change(figure=(fig, ax))

    im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Attach the colorbar to this figure explicitly, not to pyplot's current one.
    fig.colorbar(im, label=ax_settings.clabel, cax=cax)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return (fig, ax)
1192
1193
def plot_misfit(
    values: np.ndarray = [],
    names: np.ndarray = [],
    columns: list | None = None,
    misfit: str = "nse",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the misfit criteria between the simulated and observed discharges as a
    bar chart, one bar per outlet.

    Parameters:
    -----------
    values: np.ndarray
        The result of the discharge misfit for all outlets. Lists are accepted.
        NaN values are not plotted.
    names: np.ndarray
        Outlet names or codes (lists are accepted). Defaults to the outlet
        indices when empty.
    columns: list | None
        Indices of the outlets to plot. None plots all outlets.
    misfit: str
        Name of the criteria, used in the y label. Choices are
        ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
    figure: tuple
        Input figure as (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    (fig, ax): the matplotlib figure and axes.
    """

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{misfit} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=8,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Convert to arrays so that fancy / boolean indexing below also works when
    # plain Python lists are passed.
    values = np.asarray(values)
    names = np.asarray(names)

    if len(names) == 0:
        names = np.arange(len(values))

    if columns is not None:
        values = values[columns]
        names = names[columns]

    # remove nan from plot
    valid = ~np.isnan(values)
    values = values[valid]
    names = names[valid]

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    bar_container = ax.bar(names, values, color="grey", tick_label=names)

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1273
1274
def plot_outlet_stats(
    values_sim: np.ndarray | None = None,
    values_obs: np.ndarray | None = None,
    names: np.ndarray = [],
    columns: list | None = [],
    stat: str = "max",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot a statistical criteria at a given list of outlets as grouped bars
    (simulated and/or observed).

    Parameters:
    -----------
    values_sim: np.ndarray or None
        The result of the simulated stat for all outlets.
    values_obs: np.ndarray or None
        The result of the observed stat for all outlets. It is ignored when every
        value equals -99.0.
    names: np.ndarray
        Outlet names or codes (lists are accepted). Defaults to the outlet
        indices when empty.
    columns: list | None
        Indices of the outlets to plot. None or an empty list plots all outlets.
    stat: str
        Name of the criteria, used in the y label. Choices are
        ['max', 'min', 'mean', 'median', 'q20', 'q80'].
    figure: tuple
        Input figure as (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    (fig, ax): the matplotlib figure and axes.

    Raises:
    -------
    ValueError
        If values_sim and values_obs are both given with different sizes.
    """

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{stat} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Arrays everywhere so that indexing by `columns` also works with lists.
    if values_sim is not None:
        values_sim = np.asarray(values_sim)
    if values_obs is not None:
        values_obs = np.asarray(values_obs)
    names = np.asarray(names)

    reference = values_sim if values_sim is not None else values_obs
    if len(names) == 0 and reference is not None:
        names = np.arange(len(reference))

    # The default `columns=[]` means "all outlets". Indexing with an empty list
    # would select nothing.
    if columns is not None and len(columns) > 0:
        if values_sim is not None:
            values_sim = values_sim[columns]

        if values_obs is not None:
            values_obs = values_obs[columns]

        names = names[columns]

    # -99.0 everywhere means no observation is available.
    if values_obs is not None and np.all(values_obs == -99.0):
        values_obs = None

    if values_sim is not None and values_obs is not None:
        if values_obs.size != values_sim.size:
            raise ValueError("values_sim and values_obs must have the same size !")

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    x = np.arange(len(names))
    width = 0.25  # the width of the bars

    n_bars = 0
    if values_sim is not None:
        ax.bar(x + width * n_bars, values_sim, width, label="sim")
        n_bars += 1

    if values_obs is not None:
        ax.bar(x + width * n_bars, values_obs, width, label="obs")
        n_bars += 1

    # Bars are centred on x, x + width, ..., so the middle of each group is at
    # x + width * (n_bars - 1) / 2.
    ax.set_xticks(x + width * (max(n_bars, 1) - 1) / 2, names)

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1379
1380
def plot_misfit_map(
    values: np.ndarray = [],
    names: np.ndarray = [],
    mesh=None,
    misfit: str = "nse",
    coef_hydro=99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Map plot of the misfit criteria between the simulated and observed discharges.
    Each gauge of the mesh is drawn as a marker colored by its misfit value and
    annotated with its code and value, over the drained-area network.

    Parameters:
    -----------
    values: np.ndarray
        The result of the discharge misfit for all outlets, in the same order as
        mesh["code"].
    names: np.ndarray
        Outlet names or codes stored in an np.ndarray (currently unused; the
        labels come from mesh["code"]).
    mesh: dict
        The mesh of the Smash model as a dict (mandatory).
    misfit: str
        Criteria to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
        It sets the color bounds: [0, 1] for nse/nnse, [0, max] for
        rmse/nrmse/se, and [min, max] otherwise.
    coef_hydro: float
        Percentage used to hide small streams. Cells whose drained area is below
        (1 - coef_hydro / 100) * max(drained area) are not drawn.
    catchment_polygon: DataFrame or None
        Polygons (e.g. catchment boundaries read by geopandas) drawn on top.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.
    plot_settings: dict or class plot_properties
        Object or dict with any attribute of class plot_properties. Its `color`
        is ignored because the markers are colored by the misfit value.

    Returns:
    --------
    (fig, ax): the matplotlib figure and axes.

    Raises:
    -------
    ValueError
        If mesh is None or is not a dict.
    """
    if mesh is None:
        raise ValueError(
            "model or mesh are mandatory and must be a dict or a smash Model object"
        )
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title=f"Map of {misfit} criteria over the domain.",
            xlabel="x_coords",
            ylabel="y_coords",
            cmap="turbo_r",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="o",
            markersize=8,
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    # unset attribute color, managed separately (colored by misfit value)
    delattr(default_plot_settings, "color")

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    flow_acc = mesh["flwacc"]
    na = mesh["active_cell"] == 0

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Scale by 1e-6 so the values match the "km²" colorbar label below
    # (assumes flwacc is stored in m²).
    flow_accum_bv = np.where(na, 0.0, flow_acc.data / 1000000.0)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    mask_flow = flow_accum_bv < surfmin
    flow_plot = np.where(mask_flow | na, np.nan, flow_accum_bv)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    # Light-gray background for the active domain.
    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    greys = matplotlib.colormaps["binary"]
    flow_cmap = ListedColormap(greys(np.linspace(0.20, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=flow_cmap, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    fig.colorbar(
        im, ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax
    )

    # define bounds for the colormap (NaN misfits are ignored)
    if misfit in ("nse", "nnse"):
        vmin, vmax = 0, 1
    elif misfit in ("rmse", "nrmse", "se"):
        vmin, vmax = 0, np.nanmax(values)
    else:
        vmin, vmax = np.nanmin(values), np.nanmax(values)

    # cm.get_cmap was removed in Matplotlib 3.9. The registry accepts a name, a
    # Colormap or None (rc default). Misfit values are mapped to [0, 1] with
    # `norm` before looking up the color, so any [vmin, vmax] range is honoured.
    colormap = matplotlib.colormaps.get_cmap(default_ax_settings.cmap)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    ha = "right"
    for i, station in enumerate(stations):
        value = values[i]
        str_val = str(np.round(value, 2))

        # Alternate the label side between consecutive stations.
        if ha == "right":
            ha = "left"
            str_val = str_val.rjust(len(station))
        else:
            ha = "right"
            str_val = str_val.ljust(len(station))
        code = f"{station}\n {str_val}"

        color = colormap(norm(value))

        coord = geo_toolbox.rowcol_to_xy(
            gauge[i][0],
            gauge[i][1],
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        )

        ax.plot(
            coord[0],
            coord[1],
            color=color,
            **default_plot_settings.__dict__,
        )

        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=ha,
            color=color,
            fontsize=default_ax_settings.annotate_fontsize
            * default_ax_settings.font_ratio,
        )

    # Second colorbar, for the misfit values, further to the right.
    cax = divider.append_axes("right", size="5%", pad=0.5)

    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=colormap),
        ax=ax,
        cax=cax,
        label=misfit,
        shrink=0.75,
        location="right",
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1568
1569
1570# def _ax_settings(
1571#     figure,
1572#     title: str = None,
1573#     xlabel: str = None,
1574#     ylabel: str = None,
1575#     clabel: str = None,
1576#     font_ratio: int = 1,
1577#     title_fontsize: int = 12,
1578#     label_fontsize: int = 10,
1579#     grid: bool = True,
1580#     xscale: str | None = None,
1581#     yscale: str | None = None,
1582#     legend: bool = True,
1583#     legend_loc: str = None,
1584#     legend_fontsize: int = 8,
1585#     xtics_fontsize: int = 8,
1586#     ytics_fontsize: int = 8,
1587#     cmap: str | None = None,
1588#     xlim: tuple | list | None = (None, None),
1589#     ylim: tuple | list | None = (None, None),
1590# ):
1591
1592#     fig, ax = figure
1593
1594#     plt.rcParams.update(plt.rcParamsDefault)
1595
1596#     if title is not None:
1597#         ax.set_title(title, fontsize=title_fontsize * font_ratio)
1598
1599#     if xlabel is not None:
1600#         ax.set_xlabel(xlabel)
1601
1602#     if ylabel is not None:
1603#         ax.set_ylabel(ylabel)
1604
1605#     if grid:
1606#         ax.grid(True, which="both", linestyle="--", alpha=0.5)
1607
1608#     if xscale is not None:
1609#         ax.set_xscale(xscale)
1610
1611#     if yscale is not None:
1612#         ax.set_xyscale(yscale)
1613
1614#     if legend:
1615#         ax.legend(loc=legend_loc)
1616
1617#     if cmap is not None:
1618#         plt.rc("image", cmap=cmap)
1619
1620#     if ylim[0] != None:
1621#         ax.set_ylim(bottom=ylim[0])
1622
1623#     if ylim[1] != None:
1624#         ax.set_ylim(top=ylim[1])
1625
1626#     if xlim[0] != None:
1627#         ax.set_xlim(left=xlim[0])
1628
1629#     if xlim[1] != None:
1630#         ax.set_xlim(right=xlim[1])
1631
1632#     plt.rcParams.update(
1633#         {
1634#             "axes.labelsize": label_fontsize * font_ratio,
1635#             "axes.titlesize": title_fontsize * font_ratio,
1636#             "legend.fontsize": legend_fontsize * font_ratio,
1637#             "figure.titlesize": title_fontsize * font_ratio,
1638#             "xtick.labelsize": xtics_fontsize * font_ratio,
1639#             "ytick.labelsize": ytics_fontsize * font_ratio,
1640#         }
1641#     )
1642
1643#     return fig, ax
1644
1645
1646# def _fig_settings(
1647#     figure,
1648#     figname=None,
1649#     xsize=8,
1650#     ysize=6,
1651#     transparent=False,
1652#     dpi=80,
1653#     font_ratio=1,
1654#     bbox_inches="tight",
1655# ):
1656
1657#     fig, ax = figure
1658
1659#     fig.set_figheight(ysize)
1660#     fig.set_figwidth(xsize)
1661
1662#     plt.rc(
1663#         "font", size=plt.rcParams["font.size"] * font_ratio
1664#     )  # controls default text sizes
1665#     plt.rc(
1666#         "axes", titlesize=plt.rcParams["axes.titlesize"] * font_ratio
1667#     )  # fontsize of the axes title
1668#     plt.rc(
1669#         "axes", labelsize=plt.rcParams["axes.labelsize"] * font_ratio
1670#     )  # fontsize of the x and y labels
1671#     plt.rc(
1672#         "xtick", labelsize=plt.rcParams["xtick.labelsize"] * font_ratio
1673#     )  # fontsize of the tick labels
1674#     plt.rc(
1675#         "ytick", labelsize=plt.rcParams["ytick.labelsize"] * font_ratio
1676#     )  # fontsize of the tick labels
1677#     plt.rc(
1678#         "legend", fontsize=plt.rcParams["legend.fontsize"] * font_ratio
1679#     )  # legend fontsize
1680#     plt.rc(
1681#         "figure", titlesize=plt.rcParams["figure.titlesize"] * font_ratio
1682#     )  # fontsize of the figure title
1683
1684#     if figname is not None:
1685#         fig.savefig(figname, transparent=transparent, dpi=dpi, bbox_inches=bbox_inches)
1686
1687#     return fig, ax
class plot_properties:
26class plot_properties:
27    """Class which handle differents properties of the matplotlib plot function.
28    All attributes can be defined by the user and the object plot_properties can be
29     passed to any smashbox plot function.
30    """
31
32    def __init__(
33        self,
34        ls="-",
35        lw=1.5,
36        marker="",
37        markersize=4,
38        color="black",
39        label="",
40    ):
41        self.ls = ls
42        """The style of the line (see matplotlib documentation)"""
43        self.lw = lw
44        """The linewidth of the line (see matplotlib documentation)"""
45        self.marker = marker
46        """The style of the marker (see matplotlib documentation)"""
47        self.markersize = markersize
48        """The size of the marker (see matplotlib documentation)"""
49        self.color = color
50        """The color of the line (see matplotlib documentation)"""
51        self.label = label
52        """The label of the line (see matplotlib documentation)"""
53
54    def update(self, **kwargs):
55        """Update the class attributes using kwarg (dictionnary)"""
56        for key, values in kwargs.items():
57            setattr(self, key, values)

Class which handles the different properties of the matplotlib plot function. All attributes can be defined by the user, and the plot_properties object can be passed to any smashbox plot function.

plot_properties(ls='-', lw=1.5, marker='', markersize=4, color='black', label='')
32    def __init__(
33        self,
34        ls="-",
35        lw=1.5,
36        marker="",
37        markersize=4,
38        color="black",
39        label="",
40    ):
41        self.ls = ls
42        """The style of the line (see matplotlib documentation)"""
43        self.lw = lw
44        """The linewidth of the line (see matplotlib documentation)"""
45        self.marker = marker
46        """The style of the marker (see matplotlib documentation)"""
47        self.markersize = markersize
48        """The size of the marker (see matplotlib documentation)"""
49        self.color = color
50        """The color of the line (see matplotlib documentation)"""
51        self.label = label
52        """The label of the line (see matplotlib documentation)"""
ls

The style of the line (see matplotlib documentation)

lw

The linewidth of the line (see matplotlib documentation)

marker

The style of the marker (see matplotlib documentation)

markersize

The size of the marker (see matplotlib documentation)

color

The color of the line (see matplotlib documentation)

label

The label of the line (see matplotlib documentation)

def update(self, **kwargs):
54    def update(self, **kwargs):
55        """Update the class attributes using kwarg (dictionnary)"""
56        for key, values in kwargs.items():
57            setattr(self, key, values)

Update the class attributes from keyword arguments (dictionary).

class ax_properties:
 60class ax_properties:
 61    """Class which handle differents properties of the matplotlib ax object
 62    (see matplotlib documentation). All attributes can be defined by the user and the
 63     object ax_properties can be passed to any smashbox plot function.
 64    """
 65
 66    def __init__(
 67        self,
 68        title: str = None,
 69        xlabel: str = None,
 70        ylabel: str = None,
 71        clabel: str = None,
 72        font_ratio: int = 1,
 73        title_fontsize: int = 12,
 74        label_fontsize: int = 10,
 75        annotate_fontsize: int = 6,
 76        grid: bool = True,
 77        xscale: str | None = None,
 78        yscale: str | None = None,
 79        legend: bool = True,
 80        legend_loc: str = None,
 81        legend_fontsize: int = 8,
 82        xtics_fontsize: int = 8,
 83        ytics_fontsize: int = 8,
 84        cmap: str | None = None,
 85        xlim: tuple | list | None = (None, None),
 86        ylim: tuple | list | None = (None, None),
 87        xticklabels_rotation: int = 0,
 88        barlabel_fontsize=6,
 89    ):
 90
 91        self.title = title
 92        """The title of the graphic"""
 93        self.xlabel = xlabel
 94        """The label of the x axis"""
 95        self.ylabel = ylabel
 96        """The label of the y axis"""
 97        self.clabel = clabel
 98        """The label of the colorbar"""
 99        self.font_ratio = font_ratio
100        """Ratio of the global fontsize"""
101        self.title_fontsize = title_fontsize
102        """The fontsize of the title"""
103        self.label_fontsize = label_fontsize
104        """The label fontsize"""
105        self.annotate_fontsize = annotate_fontsize
106        """The annotation fontsize in plot"""
107        self.grid = grid
108        """Set the grid (boolean), default True"""
109        self.xscale = xscale
110        """Scale of the x axis (see matplotlib documentation)"""
111        self.yscale = yscale
112        """Scale of the x axis (see matplotlib documentation)"""
113        self.legend = legend
114        """Set the legend, boolean, default is True"""
115        self.legend_loc = legend_loc
116        """Localisation of the legend"""
117        self.legend_fontsize = legend_fontsize
118        """The fontsize of the legend"""
119        self.xtics_fontsize = xtics_fontsize
120        """The fontsize of the xtics"""
121        self.ytics_fontsize = ytics_fontsize
122        """The fontsize of the ytics"""
123        self.cmap = cmap
124        """The name of the used colormap"""
125        self.xlim = xlim
126        """The limit of the x axis, tuple or list, default is (None,None)"""
127        self.ylim = ylim
128        """The limit of the y axis, tuple or list, default is (None,None)"""
129        self.xticklabels_rotation = xticklabels_rotation
130        """Angle of the xtics labels, float, default is 0."""
131        self.barlabel_fontsize = barlabel_fontsize * self.font_ratio
132        "Fontsize of the bar top label"
133
134    def update(self, **kwargs):
135        """Update the class attributes using kwarg (dictionnary)"""
136        for key, values in kwargs.items():
137            setattr(self, key, values)
138
139    def change(self, figure):
140        """Apply change to the current figure `figure`
141        Parameter:
142        ----------
143        figure : list or tuple
144            list of (fig, ax), figure and ax of a matplotlib subplot.
145        Return:
146        -------
147            a tuple of the modified (fig,ax)
148        """
149        fig, ax = figure
150
151        plt.rcParams.update(plt.rcParamsDefault)
152
153        plt.rcParams.update(
154            {
155                "axes.labelsize": self.label_fontsize * self.font_ratio,
156                "axes.titlesize": self.title_fontsize * self.font_ratio,
157                "legend.fontsize": self.legend_fontsize * self.font_ratio,
158                "figure.titlesize": self.title_fontsize * self.font_ratio,
159                "xtick.labelsize": self.xtics_fontsize * self.font_ratio,
160                "ytick.labelsize": self.ytics_fontsize * self.font_ratio,
161            }
162        )
163
164        if self.title is not None:
165
166            ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio)
167
168        if self.xlabel is not None:
169            ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio)
170
171        if self.ylabel is not None:
172            ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio)
173
174        if self.grid:
175            ax.grid(True, which="both", linestyle="--", alpha=0.5)
176
177        if self.xscale is not None:
178            ax.set_xscale(self.xscale)
179
180        if self.yscale is not None:
181            ax.set_xyscale(self.yscale)
182
183        if self.legend:
184            ax.legend(
185                loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio
186            )
187
188        if self.cmap is not None:
189            plt.rc("image", cmap=self.cmap)
190
191        if self.ylim[0] is not None:
192            ax.set_ylim(bottom=self.ylim[0])
193
194        if self.ylim[1] is not None:
195            ax.set_ylim(top=self.ylim[1])
196
197        if self.xlim[0] is not None:
198            ax.set_xlim(left=self.xlim[0])
199
200        if self.xlim[1] is not None:
201            ax.set_xlim(right=self.xlim[1])
202
203        if self.xticklabels_rotation > 0:
204            ax.set_xticklabels(
205                ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right"
206            )
207
208        return fig, ax

Class which handles the different properties of the matplotlib ax object (see matplotlib documentation). All attributes can be defined by the user, and the ax_properties object can be passed to any smashbox plot function.

ax_properties( title: str = None, xlabel: str = None, ylabel: str = None, clabel: str = None, font_ratio: int = 1, title_fontsize: int = 12, label_fontsize: int = 10, annotate_fontsize: int = 6, grid: bool = True, xscale: str | None = None, yscale: str | None = None, legend: bool = True, legend_loc: str = None, legend_fontsize: int = 8, xtics_fontsize: int = 8, ytics_fontsize: int = 8, cmap: str | None = None, xlim: tuple | list | None = (None, None), ylim: tuple | list | None = (None, None), xticklabels_rotation: int = 0, barlabel_fontsize=6)
 66    def __init__(
 67        self,
 68        title: str = None,
 69        xlabel: str = None,
 70        ylabel: str = None,
 71        clabel: str = None,
 72        font_ratio: int = 1,
 73        title_fontsize: int = 12,
 74        label_fontsize: int = 10,
 75        annotate_fontsize: int = 6,
 76        grid: bool = True,
 77        xscale: str | None = None,
 78        yscale: str | None = None,
 79        legend: bool = True,
 80        legend_loc: str = None,
 81        legend_fontsize: int = 8,
 82        xtics_fontsize: int = 8,
 83        ytics_fontsize: int = 8,
 84        cmap: str | None = None,
 85        xlim: tuple | list | None = (None, None),
 86        ylim: tuple | list | None = (None, None),
 87        xticklabels_rotation: int = 0,
 88        barlabel_fontsize=6,
 89    ):
 90
 91        self.title = title
 92        """The title of the graphic"""
 93        self.xlabel = xlabel
 94        """The label of the x axis"""
 95        self.ylabel = ylabel
 96        """The label of the y axis"""
 97        self.clabel = clabel
 98        """The label of the colorbar"""
 99        self.font_ratio = font_ratio
100        """Ratio of the global fontsize"""
101        self.title_fontsize = title_fontsize
102        """The fontsize of the title"""
103        self.label_fontsize = label_fontsize
104        """The label fontsize"""
105        self.annotate_fontsize = annotate_fontsize
106        """The annotation fontsize in plot"""
107        self.grid = grid
108        """Set the grid (boolean), default True"""
109        self.xscale = xscale
110        """Scale of the x axis (see matplotlib documentation)"""
111        self.yscale = yscale
112        """Scale of the x axis (see matplotlib documentation)"""
113        self.legend = legend
114        """Set the legend, boolean, default is True"""
115        self.legend_loc = legend_loc
116        """Localisation of the legend"""
117        self.legend_fontsize = legend_fontsize
118        """The fontsize of the legend"""
119        self.xtics_fontsize = xtics_fontsize
120        """The fontsize of the xtics"""
121        self.ytics_fontsize = ytics_fontsize
122        """The fontsize of the ytics"""
123        self.cmap = cmap
124        """The name of the used colormap"""
125        self.xlim = xlim
126        """The limit of the x axis, tuple or list, default is (None,None)"""
127        self.ylim = ylim
128        """The limit of the y axis, tuple or list, default is (None,None)"""
129        self.xticklabels_rotation = xticklabels_rotation
130        """Angle of the xtics labels, float, default is 0."""
131        self.barlabel_fontsize = barlabel_fontsize * self.font_ratio
132        "Fontsize of the bar top label"
title

The title of the graphic

xlabel

The label of the x axis

ylabel

The label of the y axis

clabel

The label of the colorbar

font_ratio

Ratio of the global fontsize

title_fontsize

The fontsize of the title

label_fontsize

The label fontsize

annotate_fontsize

The annotation fontsize in plot

grid

Set the grid (boolean), default True

xscale

Scale of the x axis (see matplotlib documentation)

yscale

Scale of the y axis (see matplotlib documentation)

legend

Set the legend, boolean, default is True

legend_loc

Location of the legend

legend_fontsize

The fontsize of the legend

xtics_fontsize

The fontsize of the xtics

ytics_fontsize

The fontsize of the ytics

cmap

The name of the used colormap

xlim

The limit of the x axis, tuple or list, default is (None,None)

ylim

The limit of the y axis, tuple or list, default is (None,None)

xticklabels_rotation

Rotation angle of the x tick labels, default is 0 (labels are only rotated when the value is greater than 0).

barlabel_fontsize

Font size of the labels drawn on top of the bars (multiplied by font_ratio when the object is created).

def update(self, **kwargs):
134    def update(self, **kwargs):
135        """Update the class attributes using kwarg (dictionnary)"""
136        for key, values in kwargs.items():
137            setattr(self, key, values)

Update the class attributes using keyword arguments (a dictionary of attribute names and values).

def change(self, figure):
139    def change(self, figure):
140        """Apply change to the current figure `figure`
141        Parameter:
142        ----------
143        figure : list or tuple
144            list of (fig, ax), figure and ax of a matplotlib subplot.
145        Return:
146        -------
147            a tuple of the modified (fig,ax)
148        """
149        fig, ax = figure
150
151        plt.rcParams.update(plt.rcParamsDefault)
152
153        plt.rcParams.update(
154            {
155                "axes.labelsize": self.label_fontsize * self.font_ratio,
156                "axes.titlesize": self.title_fontsize * self.font_ratio,
157                "legend.fontsize": self.legend_fontsize * self.font_ratio,
158                "figure.titlesize": self.title_fontsize * self.font_ratio,
159                "xtick.labelsize": self.xtics_fontsize * self.font_ratio,
160                "ytick.labelsize": self.ytics_fontsize * self.font_ratio,
161            }
162        )
163
164        if self.title is not None:
165
166            ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio)
167
168        if self.xlabel is not None:
169            ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio)
170
171        if self.ylabel is not None:
172            ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio)
173
174        if self.grid:
175            ax.grid(True, which="both", linestyle="--", alpha=0.5)
176
177        if self.xscale is not None:
178            ax.set_xscale(self.xscale)
179
180        if self.yscale is not None:
181            ax.set_xyscale(self.yscale)
182
183        if self.legend:
184            ax.legend(
185                loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio
186            )
187
188        if self.cmap is not None:
189            plt.rc("image", cmap=self.cmap)
190
191        if self.ylim[0] is not None:
192            ax.set_ylim(bottom=self.ylim[0])
193
194        if self.ylim[1] is not None:
195            ax.set_ylim(top=self.ylim[1])
196
197        if self.xlim[0] is not None:
198            ax.set_xlim(left=self.xlim[0])
199
200        if self.xlim[1] is not None:
201            ax.set_xlim(right=self.xlim[1])
202
203        if self.xticklabels_rotation > 0:
204            ax.set_xticklabels(
205                ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right"
206            )
207
208        return fig, ax

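Apply the settings to the given figure (the figure argument). Note that matplotlib's rcParams are first reset to their defaults before the font sizes are applied.

Parameters:

figure : list or tuple. A (fig, ax) pair: the figure and the axes of a matplotlib subplot.

Returns:

A tuple of the modified (fig, ax).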
class fig_properties:
211class fig_properties:
212    """Class which handle differents properties of the matplotlib fig object
213    (see matplotlib documentation). All attributes can be defined by the user and the
214     object ax_properties can be passed to any smashbox plot function.
215    """
216
217    def __init__(
218        self,
219        figname=None,
220        xsize=8,
221        ysize=6,
222        transparent=False,
223        dpi=160,
224        font_ratio=1,
225        bbox_inches="tight",
226    ):
227
228        self.figname = figname
229        """Path to the figure name to be saved"""
230        self.xsize = xsize
231        """Width of the figure in inch"""
232        self.ysize = ysize
233        """Height of the figure in inch"""
234        self.transparent = transparent
235        """Use transparency when exporting the figure, default is False"""
236        self.dpi = dpi
237        """RƩsolution (dpi), int, default is 80"""
238        self.font_ratio = font_ratio
239        """Global font ratio"""
240        self.bbox_inches = bbox_inches
241        """Constraint of the boundingbox of each ax (see matplotlib docuentation)"""
242
243    def update(self, **kwargs):
244        """Update the class attributes using kwarg (dictionnary)"""
245
246        for key, values in kwargs.items():
247            setattr(self, key, values)
248
249    def change(self, figure):
250        """Apply change to the current figure `figure`
251        Parameter:
252        ----------
253        figure : list or tuple
254            list of (fig, ax), figure and ax of a matplotlib subplot.
255        Return:
256        -------
257            a tuple of the modified (fig,ax)
258        """
259
260        fig, ax = figure
261
262        fig.set_figheight(self.ysize)
263        fig.set_figwidth(self.xsize)
264
265        plt.rc(
266            "font", size=plt.rcParams["font.size"] * self.font_ratio
267        )  # controls default text sizes
268        plt.rc(
269            "axes", titlesize=plt.rcParams["axes.titlesize"] * self.font_ratio
270        )  # fontsize of the axes title
271        plt.rc(
272            "axes", labelsize=plt.rcParams["axes.labelsize"] * self.font_ratio
273        )  # fontsize of the x and y labels
274        plt.rc(
275            "xtick", labelsize=plt.rcParams["xtick.labelsize"] * self.font_ratio
276        )  # fontsize of the tick labels
277        plt.rc(
278            "ytick", labelsize=plt.rcParams["ytick.labelsize"] * self.font_ratio
279        )  # fontsize of the tick labels
280        plt.rc(
281            "legend", fontsize=plt.rcParams["legend.fontsize"] * self.font_ratio
282        )  # legend fontsize
283        plt.rc(
284            "figure", titlesize=plt.rcParams["figure.titlesize"] * self.font_ratio
285        )  # fontsize of the figure title
286
287        if self.figname is not None:
288
289            head_path, basename = os.path.split(self.figname)
290
291            if len(head_path) > 0 and not os.path.exists(head_path):
292                os.makedirs(head_path)
293
294            fig.savefig(
295                self.figname,
296                transparent=self.transparent,
297                dpi=self.dpi,
298                bbox_inches=self.bbox_inches,
299            )
300
301        return fig, ax

Class which handles the different properties of the matplotlib fig object (see matplotlib documentation). All attributes can be defined by the user, and the fig_properties object can be passed to any smashbox plot function.

fig_properties( figname=None, xsize=8, ysize=6, transparent=False, dpi=160, font_ratio=1, bbox_inches='tight')
217    def __init__(
218        self,
219        figname=None,
220        xsize=8,
221        ysize=6,
222        transparent=False,
223        dpi=160,
224        font_ratio=1,
225        bbox_inches="tight",
226    ):
227
228        self.figname = figname
229        """Path to the figure name to be saved"""
230        self.xsize = xsize
231        """Width of the figure in inch"""
232        self.ysize = ysize
233        """Height of the figure in inch"""
234        self.transparent = transparent
235        """Use transparency when exporting the figure, default is False"""
236        self.dpi = dpi
237        """RƩsolution (dpi), int, default is 80"""
238        self.font_ratio = font_ratio
239        """Global font ratio"""
240        self.bbox_inches = bbox_inches
241        """Constraint of the boundingbox of each ax (see matplotlib docuentation)"""
figname

Path of the file the figure is saved to (the figure is not saved when None)

xsize

Width of the figure in inches

ysize

Height of the figure in inches

transparent

Use transparency when exporting the figure, default is False

dpi

Resolution (dpi), int, default is 160

font_ratio

Global font ratio

bbox_inches

Bounding box setting passed to savefig when the figure is saved, default is 'tight' (see matplotlib documentation)

def update(self, **kwargs):
243    def update(self, **kwargs):
244        """Update the class attributes using kwarg (dictionnary)"""
245
246        for key, values in kwargs.items():
247            setattr(self, key, values)

Update the class attributes using keyword arguments (a dictionary of attribute names and values).

def change(self, figure):
249    def change(self, figure):
250        """Apply change to the current figure `figure`
251        Parameter:
252        ----------
253        figure : list or tuple
254            list of (fig, ax), figure and ax of a matplotlib subplot.
255        Return:
256        -------
257            a tuple of the modified (fig,ax)
258        """
259
260        fig, ax = figure
261
262        fig.set_figheight(self.ysize)
263        fig.set_figwidth(self.xsize)
264
265        plt.rc(
266            "font", size=plt.rcParams["font.size"] * self.font_ratio
267        )  # controls default text sizes
268        plt.rc(
269            "axes", titlesize=plt.rcParams["axes.titlesize"] * self.font_ratio
270        )  # fontsize of the axes title
271        plt.rc(
272            "axes", labelsize=plt.rcParams["axes.labelsize"] * self.font_ratio
273        )  # fontsize of the x and y labels
274        plt.rc(
275            "xtick", labelsize=plt.rcParams["xtick.labelsize"] * self.font_ratio
276        )  # fontsize of the tick labels
277        plt.rc(
278            "ytick", labelsize=plt.rcParams["ytick.labelsize"] * self.font_ratio
279        )  # fontsize of the tick labels
280        plt.rc(
281            "legend", fontsize=plt.rcParams["legend.fontsize"] * self.font_ratio
282        )  # legend fontsize
283        plt.rc(
284            "figure", titlesize=plt.rcParams["figure.titlesize"] * self.font_ratio
285        )  # fontsize of the figure title
286
287        if self.figname is not None:
288
289            head_path, basename = os.path.split(self.figname)
290
291            if len(head_path) > 0 and not os.path.exists(head_path):
292                os.makedirs(head_path)
293
294            fig.savefig(
295                self.figname,
296                transparent=self.transparent,
297                dpi=self.dpi,
298                bbox_inches=self.bbox_inches,
299            )
300
301        return fig, ax

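Apply the settings to the given figure (the figure argument). If figname is set, the figure is also saved to that path, creating any missing parent directories.

Parameters:

figure : list or tuple. A (fig, ax) pair: the figure and the axes of a matplotlib subplot.

Returns:

A tuple of the modified (fig, ax).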
def save_figure(*args, **kwargs):
 95    def wrapper(*args, **kwargs):
 96
 97        bound = sig.bind(*args, **kwargs)
 98        bound.apply_defaults()
 99
100        for name, value in bound.arguments.items():
101            if name in annotations:
102
103                target_type = annotations[name]
104
105                args_ = get_args(target_type)
106
107                if target_type is None and len(args_) == 0:
108                    args_ = (type(None),)
109                    target_type = type(None)
110
111                if not type(value) in args_:
112
113                    if len(args_) > 1 and type(None) in args_:
114
115                        converted = False
116                        for t in args_:
117
118                            if t is not type(None):
119
120                                if value is not None:
121                                    try:
122                                        print(
123                                            f"</> Warning: Arg '{name}' of type {type(value)} is being"
124                                            f" converted to {t}"
125                                        )
126                                        bound.arguments[name] = t(value)
127                                        converted = True
128                                    except:
129                                        pass
130
131                                if converted:
132                                    break
133
134                        if not converted:
135                            raise TypeError(
136                                f"</> Error: Arg '{name}' must be a type of "
137                                f" {args_}, got {value}"
138                                f" ({type(value).__name__})"
139                            )
140
141                    else:
142                        if not isinstance(value, target_type):
143                            try:
144                                print(
145                                    f"</> Warning: Arg '{name}' of type {type(value)} is being"
146                                    f" converted to {target_type}"
147                                )
148                                bound.arguments[name] = target_type(value)
149                            except Exception:
150                                raise TypeError(
151                                    f"</> Error: Arg '{name}' must be a type of "
152                                    f" {target_type.__name__}, got {value}"
153                                    f" ({type(value).__name__})"
154                                )
155
156        return func(*bound.args, **bound.kwargs)

Save a figure.

Parameters:

fig: the figure object returned by matplotlib.pyplot.subplots; the figure to save. figname: str, path of the output file. xsize: int, width of the figure in inches. ysize: int, height of the figure in inches. transparent: bool, default is False; whether to use transparency. dpi: int, resolution of the figure, default is 80.

def generate_palette(base_color, n, variation='hue'):
329def generate_palette(base_color, n, variation="hue"):
330    """
331    Generate a palette of colors from a base color.
332    Parameter:
333    ---------
334    base_color: str
335        matplotlib color string
336    n: int
337        number of color to generate
338    variation: 'hue' | 'brightness'
339        how to generate the color palette, by changing the hue or the brighness of the base color.
340    Return: a list of colors
341    """
342    # Convertir la couleur de base en format RGB normalisƩ (0-1)
343    rgb = mcolors.to_rgb(base_color)
344
345    # Convertir en HSV
346    h, s, v = colorsys.rgb_to_hsv(*rgb)
347
348    # GƩnƩrer n couleurs en modifiant la teinte ou valeur
349    palette = []
350    for i in range(n):
351        new_h = h
352        if variation == "hue":
353            new_h = (h + i / n) % 1.0  # cycle dans le cercle chromatique
354        new_v = v
355        if variation == "brightness":
356            new_v = max(0.1, min(1.0, v * (0.5 + i / (2 * n))))  # Ʃviter le noir complet
357
358        new_rgb = colorsys.hsv_to_rgb(new_h, s, new_v)
359        palette.append(new_rgb)
360
361    return palette

Generate a palette of colors from a base color.

Parameter:

base_color: str, a matplotlib color string. n: int, the number of colors to generate. variation: 'hue' | 'brightness', how to generate the palette: by changing either the hue or the brightness of the base color. Returns: a list of n colors as RGB tuples with components between 0 and 1.

def plot_chro(*args, **kwargs):
 95    def wrapper(*args, **kwargs):
 96
 97        bound = sig.bind(*args, **kwargs)
 98        bound.apply_defaults()
 99
100        for name, value in bound.arguments.items():
101            if name in annotations:
102
103                target_type = annotations[name]
104
105                args_ = get_args(target_type)
106
107                if target_type is None and len(args_) == 0:
108                    args_ = (type(None),)
109                    target_type = type(None)
110
111                if not type(value) in args_:
112
113                    if len(args_) > 1 and type(None) in args_:
114
115                        converted = False
116                        for t in args_:
117
118                            if t is not type(None):
119
120                                if value is not None:
121                                    try:
122                                        print(
123                                            f"</> Warning: Arg '{name}' of type {type(value)} is being"
124                                            f" converted to {t}"
125                                        )
126                                        bound.arguments[name] = t(value)
127                                        converted = True
128                                    except:
129                                        pass
130
131                                if converted:
132                                    break
133
134                        if not converted:
135                            raise TypeError(
136                                f"</> Error: Arg '{name}' must be a type of "
137                                f" {args_}, got {value}"
138                                f" ({type(value).__name__})"
139                            )
140
141                    else:
142                        if not isinstance(value, target_type):
143                            try:
144                                print(
145                                    f"</> Warning: Arg '{name}' of type {type(value)} is being"
146                                    f" converted to {target_type}"
147                                )
148                                bound.arguments[name] = target_type(value)
149                            except Exception:
150                                raise TypeError(
151                                    f"</> Error: Arg '{name}' must be a type of "
152                                    f" {target_type.__name__}, got {value}"
153                                    f" ({type(value).__name__})"
154                                )
155
156        return func(*bound.args, **bound.kwargs)

Plot a time series (chronicle) of values

Parameters:

data: np.ndarray of dimension 2, the data to plot as a 2-dimensional matrix. t_axis: int, the time axis in data, default is 1. outlets_name: list, the list of outlet names. columns: list, the columns to be plotted along the t_axis direction. dt: float, the time step. xtics: list, the list of dates for the x ticks; the format must be readable by numpy.datetime64. date_range: list, [date_start, date_end, timedelta] used to generate the x ticks. figure: tuple, an input figure as (fig, ax) to which a new curve is added. ax_settings: dict or ax_properties, an object or dict with any attribute of class ax_properties. fig_settings: dict or fig_properties, an object or dict with any attribute of class fig_properties. plot_settings: dict or plot_properties, an object or dict with any attribute of class plot_properties.

def plot_hydrograph( model: smash.core.model.model.Model | None = None, columns: list | tuple = [], outlets_name: list | tuple = [], plot_rainfall: bool = True, plot_qobs: bool = True, figure: list | tuple | None = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}, plot_settings_sim: dict | plot_properties = {}, plot_settings_obs: dict | plot_properties = {}):
def plot_hydrograph(
    model: Model | None = None,
    columns: list | tuple = [],
    outlets_name: list | tuple = [],
    plot_rainfall: bool = True,
    plot_qobs: bool = True,
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings_sim: dict | plot_properties = {},
    plot_settings_obs: dict | plot_properties = {},
):
    """
    Plot an hydrograph from a smash model
    Parameters:
    -----------
    model: smash Model
        The smash model object. Observed discharges are read from
        `model.response_data.q` (negative values are drawn as missing),
        simulated discharges from `model.response.q` and the rainfall from
        `model.atmos_data.mean_prcp`.
    columns : list
        the column to be plotted in t_axis direction. The first column (or 0
        if empty) is also the one used for the rainfall bars.
    outlets_name: list
        the list of the outlets name
    plot_rainfall: bool
        If True, draw the average rainfall as inverted bars in an axis placed
        above the hydrograph.
    plot_qobs: bool
        If True, draw the observed discharges.
    figure: tuple
        input figure to add a new curve: (fig, ax) or, when plot_rainfall is
        True, (fig, ax, ax_rainfall).
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings_sim: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the
        simulated curve
    plot_settings_obs: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Control the
        observed curve

    Returns:
    --------
    (fig, ax) as returned by fig_settings.change().

    Raises:
    -------
    ValueError
        If model is None, or if plot_rainfall is True and the given figure
        does not provide a third (rainfall) axis.
    """
    if model is None:
        raise ValueError("Input smash model object is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xlabel="Time", ylabel="discharges m^3/s", xtics_fontsize=10, ytics_fontsize=10
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings_sim, dict):
        default_plot_settings_sim = plot_properties(
            ls="-",
            lw=2.0,
            marker="",
            markersize=4,
            color="blue",
            label="Sim",
        )
        default_plot_settings_sim.update(**plot_settings_sim)
    else:
        default_plot_settings_sim = plot_properties(**plot_settings_sim.__dict__)

    # default color for multi curves: same color but different line type
    obs_color = "blue" if len(columns) >= 2 else "black"

    if isinstance(plot_settings_obs, dict):
        default_plot_settings_obs = plot_properties(
            ls="--",
            lw=1.5,
            marker="",
            markersize=4,
            color=obs_color,
            label="Obs",
        )
        default_plot_settings_obs.update(**plot_settings_obs)
    else:
        # Copy the user-provided plot_properties object, as for the sim curve.
        default_plot_settings_obs = plot_properties(**plot_settings_obs.__dict__)

    # x axis: from start_time + dt up to end_time, with a step of dt seconds
    date_deb = datetime.datetime.fromisoformat(
        model.setup.start_time
    ) + datetime.timedelta(seconds=int(model.setup.dt))
    date_end = datetime.datetime.fromisoformat(model.setup.end_time)
    date_range = [date_deb, date_end, model.setup.dt]

    ax2 = None  # rainfall axis, only used when plot_rainfall is True
    if figure is None:
        if plot_rainfall:
            fig, (ax2, ax1) = plt.subplots(2, 1, height_ratios=[1, 4])
            fig.subplots_adjust(hspace=0)
        else:
            fig, ax1 = plt.subplots()
    else:
        fig, ax1 = figure[0], figure[1]
        if plot_rainfall:
            if len(figure) < 3:
                raise ValueError(
                    "figure must be (fig, ax, ax_rainfall) when plot_rainfall is True."
                )
            ax2 = figure[2]

    fig, _ = default_ax_settings.change(figure=(fig, ax1))

    if plot_qobs:
        fig, ax1 = plot_chro(
            np.where(model.response_data.q < 0, np.nan, model.response_data.q),
            date_range=date_range,
            columns=columns,
            outlets_name=["obs_" + name for name in outlets_name],
            figure=(fig, ax1),
            ax_settings=default_ax_settings,
            fig_settings=fig_settings,
            plot_settings=default_plot_settings_obs,
        )

    fig, ax1 = plot_chro(
        model.response.q,
        date_range=date_range,
        columns=columns,
        outlets_name=["sim_" + name for name in outlets_name],
        figure=(fig, ax1),
        ax_settings=default_ax_settings,
        fig_settings=fig_settings,
        plot_settings=default_plot_settings_sim,
    )

    # np.arange excludes its stop value, so one extra step is added to include date_end
    xtics = np.arange(
        np.datetime64(date_range[0]),
        np.datetime64(date_range[1] + datetime.timedelta(seconds=int(date_range[2]))),
        np.timedelta64(int(date_range[2]), "s"),
    )

    axes_list = [ax1]

    if plot_rainfall:
        col = columns[0] if len(columns) > 0 else 0
        rainfall = model.atmos_data.mean_prcp[col, :]

        ax2.bar(
            xtics,
            rainfall,
            label="Average rainfall (mm)",
            width=np.timedelta64(int(date_range[2]), "s"),
            color="blue",
        )

        # Rainfall bars hang down from the top of the rainfall axis
        ax2.invert_yaxis()
        ax2.grid(alpha=0.7, ls="--")
        ax2.get_xaxis().set_visible(False)
        # Scale the axis on the same gauge as the plotted bars
        ax2.set_ylim(bottom=1.2 * np.nanmax(rainfall), top=0.0)
        ax2.set_ylabel("Average rainfall (mm)")

        axes_list.append(ax2)

    fig, ax = fig_settings.change(figure=(fig, tuple(axes_list)))

    return fig, ax

Plot an hydrograph from a smash model

Parameters:

model: a smash model object outlets_name: list the list of the outlets name columns : list the column to be plotted in t_axis direction plot_rainfall: bool if True, plot the average rainfall above the hydrograph plot_qobs: bool if True, plot the observed discharges figure: tuple input figure as (fig,ax) to add a new curve, or (fig,ax,ax_rainfall) when plot_rainfall is True ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings_sim: dict or class plot_properties object or dict with any attribute of class plot_properties. Control the simulated curve plot_settings_obs: dict or class plot_properties object or dict with any attribute of class plot_properties. Control the observed curve

def plot_catchment_surface_error( mesh: dict = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}):
def plot_catchment_surface_error(
    mesh: dict = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot, for each gauge, the relative error between the modeled and the
    observed catchment surface: (area_dln - area) / area * 100.
    Parameters:
    -----------
    mesh: dict
        The mesh of the Smash model. The keys "code", "area" and "area_dln"
        are used.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Returns:
    --------
    (fig, ax), or (None, None) when the mesh has no gauge.

    Raises:
    -------
    ValueError
        If mesh is None.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Catchment surface error (Ssim-Sobs)/Sobs *100",
            ylabel="Surface error %",
            xlabel="Outlets",
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    # Relative surface error in percent, one value per gauge
    surface_error = (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bar_container = ax.bar(
        mesh["code"], surface_error, color="grey", tick_label=mesh["code"]
    )

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    # Re-apply the axis settings after drawing; presumably needed so that tick
    # settings (rotation, fontsize) also act on the bar tick labels.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Plot, for each gauge, the relative error between the modeled and the observed catchment surface: (area_dln - area) / area * 100.

Parameters:

mesh: dict The mesh of the Smash model; the keys "code", "area" and "area_dln" are used. ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties

def plot_catchment_surface_consistency( mesh: dict = None, label: bool = True, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}, plot_settings: dict | plot_properties = {}):
def plot_catchment_surface_consistency(
    mesh: dict = None,
    label: bool = True,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the modeled surface vs the observed surface
    Parameters:
    -----------
    mesh: dict, optional
        The mesh of the Smash model, defaults to None. The keys "code", "area"
        and "area_dln" are used.
    label: bool
        If True (default), annotate each point with its gauge code.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Controls the
        marker, marker size and color of the modeled vs observed points.

    Returns:
    --------
    (fig, ax), or (None, None) when the mesh has no gauge.

    Raises:
    -------
    ValueError
        If mesh is None.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Modeled and observed surface consistency",
            ylabel="Modeled surface",
            xlabel="Observed surface",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="+",
            markersize=12,
            color="blue",
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    if len(mesh["code"]) == 0:
        print("Cannot plot this, the mesh has no gauge !")
        return None, None

    # Surfaces divided by 1000**2 (presumably m² -> km²; confirm mesh units)
    surface_model = mesh["area_dln"] / 1000.0**2.0
    surface_obs = mesh["area"] / 1000.0**2.0

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    # One marker per gauge, no connecting line
    ax.plot(
        surface_obs,
        surface_model,
        markersize=default_plot_settings.markersize,
        marker=default_plot_settings.marker,
        color=default_plot_settings.color,
        linestyle="None",
    )
    # 1:1 reference line over the range of observed surfaces
    ax.plot(
        np.linspace(min(surface_obs), max(surface_obs), 10),
        np.linspace(min(surface_obs), max(surface_obs), 10),
        linewidth=2,
        color="grey",
    )

    if label:
        # Alternate the horizontal alignment from one gauge to the next
        ha = ("left", "right")
        for i, code in enumerate(mesh["code"]):
            point = (surface_obs[i], surface_model[i])
            ax.annotate(
                code,
                point,
                textcoords="data",
                xytext=point,
                ha=ha[i % 2],
                color="red",
                fontsize=default_ax_settings.annotate_fontsize,
            )

    ax.set(xlabel=default_ax_settings.xlabel, ylabel=default_ax_settings.ylabel)

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Plot the modeled surface vs the observed surface

Parameters:

mesh: dict, optional The mesh of the Smash model, defaults to None label: bool if True, annotate each point with its gauge code, defaults to True ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties. Controls the marker, marker size and color of the modeled vs observed points

def plot_mesh( mesh: dict = None, coef_hydro: float = 99.0, catchment_polygon: None | pandas.core.frame.DataFrame = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}):
def plot_mesh(
    mesh: dict = None,
    coef_hydro: float = 99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the mesh of a smash model
    Parameters:
    -----------
    mesh: dict
        a smash mesh as dictionary
    coef_hydro: float
        the coefficient (in %) used to colorize the hydrographic network
        according to the cumulated surface: cells whose cumulated surface is
        lower than (100 - coef_hydro)% of the maximum cumulated surface are
        hidden. Default is 99, so only cells draining at least 1% of the
        maximum are shown.
    catchment_polygon: DataFrame or None
        object with a `plot(ax=..., facecolor=..., edgecolor=...)` method (e.g.
        a geopandas GeoDataFrame) whose geometries are drawn over the mesh.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Returns:
    --------
    (fig, ax)

    Raises:
    -------
    ValueError
        If mesh is None or is not a dict.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Mesh of the Smash model",
            xlabel="x_coords",
            ylabel="y_coords",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    na = mesh["active_cell"] == 0

    # np.ma.getdata works for both masked arrays and plain ndarrays
    # (ndarray.data is a memoryview and cannot be divided).
    flow_acc = np.ma.getdata(mesh["flwacc"])

    # Cumulated surface divided by 1e6 (km², as labelled on the colorbar)
    flow_accum_bv = np.where(na, 0.0, flow_acc / 1000000.0)
    # Hide cells draining less than (100 - coef_hydro)% of the maximum
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    flow_plot = np.where(flow_accum_bv < surfmin, np.nan, flow_accum_bv)
    flow_plot = np.where(na, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Background: active cells in light gray, inactive cells transparent (NaN)
    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    # Hydrographic network: upper part of the "Blues" colormap
    myblues = matplotlib.colormaps["Blues"]
    cmp = ListedColormap(myblues(np.linspace(0.30, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    fig.colorbar(
        im, cmap="Blues", ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax
    )

    for i, code in enumerate(stations):
        # Alternate the label alignment from one gauge to the next
        ha = "left" if i % 2 == 0 else "right"

        row, col = gauge[i][0], gauge[i][1]
        # Half-cell shift applied to the coordinates returned by rowcol_to_xy
        # (presumably corner -> cell centre). NOTE(review): "dx" is used for
        # both axes, which assumes square cells — confirm.
        half_cell = mesh["dx"][row, col] / 2
        coord = geo_toolbox.rowcol_to_xy(
            row,
            col,
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        ) + np.array([half_cell, -half_cell])

        ax.plot(coord[0], coord[1], color="green", marker="o", markersize=6)
        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=ha,
            color="red",
            fontsize=10,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Plot the mesh of a smash model

Parameters:

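mesh: dict a smash mesh as dictionary coef_hydro: float the coefficient (in %) used to colorize the hydrographic network according to the cumulated surface: cells whose cumulated surface is lower than (100 - coef_hydro)% of the maximum are hidden. Default is 99, so only cells draining at least 1% of the maximum cumulated surface are shown. catchment_polygon: DataFrame or None polygons (e.g. catchment boundaries) drawn over the mesh ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties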

def plot_xy_quantile( res_quantile, X, Y, res_quantile_obs=None, gauge_pos=None, figure=None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}, plot_settings: dict | plot_properties = {}):
 943def plot_xy_quantile(
 944    res_quantile,
 945    X,
 946    Y,
 947    res_quantile_obs=None,
 948    gauge_pos=None,
 949    figure=None,
 950    ax_settings: dict | ax_properties = {},
 951    fig_settings: dict | fig_properties = {},
 952    plot_settings: dict | plot_properties = {},
 953):
 954    """
 955    Plot the discharges quantiles fitting at X,Y coordinates.
 956    Parameters:
 957    -----------
 958    res_quantile: dict
 959        The result of the discharge quantile computation.
 960    res_quantile_obs: dict
 961        The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs()
 962    gauge_pos: int
 963        gauge_pos is the index of gauge in the Smash mesh for which the quantile_discharge are provided to the function.
 964    X: int
 965        Coordinates of the pixel in the row directions (X means row)
 966    Y: int
 967        Coordinates of the pixel in the column directions (Y means column)
 968    figure: tuple
 969        input figure as (fig,ax) to add a new curve
 970    ax_settings: dict or class ax_properties
 971        object or dict with any attribute of class ax_properties
 972    fig_settings: dict or class ax_properties
 973        object or dict with any attribute of class fig_settings
 974    """
 975    if isinstance(ax_settings, dict):
 976        default_ax_settings = ax_properties(
 977            xscale="log",
 978            xlabel=f"Return period (*{res_quantile['chunk_size']} days)",
 979            ylabel="Discharges (m³/s)",
 980            grid=True,
 981            legend=True,
 982        )
 983        default_ax_settings.update(**ax_settings)
 984    else:
 985        default_ax_settings = ax_properties(**ax_settings.__dict__)
 986
 987    if isinstance(fig_settings, dict):
 988        fig_settings = fig_properties(**fig_settings)
 989    else:
 990        fig_settings = fig_properties(**fig_settings.__dict__)
 991
 992    if isinstance(plot_settings, dict):
 993        default_plot_settings = plot_properties(markersize=10)
 994        default_plot_settings.update(**plot_settings)
 995    else:
 996        default_plot_settings = plot_properties(**plot_settings.__dict__)
 997
 998    quantile = res_quantile["Q_th"][X, Y]
 999    maxima = res_quantile["maxima"][X, Y]
1000    T_emp = res_quantile["T_emp"]
1001    loc = res_quantile["fit_loc"][X, Y]
1002    scale = res_quantile["fit_scale"][X, Y]
1003    shape = res_quantile["fit_shape"][X, Y]
1004    fit = res_quantile["fit"]
1005
1006    sorted_data = np.sort(maxima)
1007
1008    plt.rcParams.update(plt.rcParamsDefault)
1009    if figure is None:
1010        fig, ax = plt.subplots()
1011    else:
1012        fig, ax = figure
1013
1014    fig, ax = default_ax_settings.change(figure=(fig, ax))
1015
1016    if res_quantile_obs is not None and len(res_quantile_obs.keys()) > 0:
1017        if gauge_pos is None:
1018            raise ValueError(
1019                "gauge_pos is None. gauge_pos argument must be an integer corresponding to the gauge index."
1020            )
1021        maxima_obs = res_quantile_obs["maxima"][gauge_pos, :]
1022        T_emp_obs = res_quantile_obs["Temp"][gauge_pos, :]
1023
1024        ax.plot(
1025            T_emp_obs,
1026            maxima_obs,
1027            "o",
1028            label="Observed",
1029            color="black",
1030            markersize=default_plot_settings.markersize,
1031        )
1032
1033    ax.plot(
1034        T_emp,
1035        sorted_data,
1036        "o",
1037        label="Empirical",
1038        markersize=default_plot_settings.markersize,
1039    )
1040
1041    ax.plot(
1042        res_quantile["T"],
1043        quantile,
1044        "x",
1045        label="Theorical",
1046        markersize=default_plot_settings.markersize,
1047    )
1048
1049    Trange = np.linspace(1.1, np.max(res_quantile["T"]), 50)
1050
1051    if fit == "gumbel":
1052        ax.plot(
1053            Trange,
1054            [stats.quantile_gumbel(T, loc, scale) for T in Trange],
1055            "r--",
1056            label=f"{fit} fitted",
1057            lw=default_plot_settings.lw,
1058        )
1059
1060    if fit == "gev":
1061        ax.plot(
1062            Trange,
1063            [stats.quantile_gev(T, shape, loc, scale) for T in Trange],
1064            "r--",
1065            label=f"{fit} fitted",
1066            lw=default_plot_settings.lw,
1067        )
1068
1069    if "Umax" in res_quantile.keys() and "Umin" in res_quantile.keys():
1070        if res_quantile["Umax"] is not None and res_quantile["Umin"] is not None:
1071            ax.plot(
1072                res_quantile["T"],
1073                res_quantile["Umax"][X, Y],
1074                "r--",
1075                label="Uncertainties (max)",
1076                color="grey",
1077                lw=default_plot_settings.lw,
1078            )
1079            ax.plot(
1080                res_quantile["T"],
1081                res_quantile["Umin"][X, Y],
1082                "r--",
1083                label="Uncertainties (min)",
1084                color="grey",
1085                lw=default_plot_settings.lw,
1086            )
1087
1088    fig, ax = default_ax_settings.change(figure=(fig, ax))
1089    fig, ax = fig_settings.change(figure=(fig, ax))
1090
1091    return fig, ax

Plot the discharges quantiles fitting at X,Y coordinates.

Parameters:

res_quantile: dict The result of the discharge quantile computation. res_quantile_obs: dict The results of the observed discharges quantile. res_quantile_obs is a dict and must be computed by the function smashbox.stats.stats.quantile_obs() gauge_pos: int gauge_pos is the index of gauge in the Smash mesh for which the quantile_discharge are provided to the function. X: int Coordinates of the pixel in the row directions (X means row) Y: int Coordinates of the pixel in the column directions (Y means column) figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties (markersize and lw are used)

def plot_image( matrice=array([[0., 0.], [0., 0.]]), bbox=None, vmin=None, vmax=None, mask=None, extend=None, catchment_polygon=None, figure=None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}):
def plot_image(
    matrice=np.zeros(shape=(2, 2)),
    bbox=None,
    vmin=None,
    vmax=None,
    mask=None,
    extend=None,
    catchment_polygon=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Function for plotting a matrix as an image

    Parameters
    ----------
    matrice : numpy array
        Matrix to be plotted. It is copied as float32 before masking, so the
        caller's array is left untouched.
    bbox : dict
        bounding box with keys "left", "right", "bottom", "top", used to put x
        and y coordinates instead of the shape of the matrix
    vmin: real,
        minimum z value. Defaults to the minimum of the matrix, ignoring NaN.
    vmax: real,
        maximum z value. Defaults to the maximum of the matrix, ignoring NaN.
    mask: integer matrix with the shape of matrice, containing 0 for pixels
        that should not be plotted
    extend: any
        Currently not used by this function (kept for backward compatibility).
    catchment_polygon: dataframe containing some polygon to be plotted.
        Ideally it must contain the boundaries of the catchment as a polygon
        from a shp file read by geopandas.
    figure: tuple
        input figure as (fig,ax) to add a new curve
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties. Its `cmap`
        and `clabel` attributes set the colormap and the colorbar label.
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Returns
    -------
    (fig, ax)

    Examples
    ----------
    plot_image(mesh_france['drained_area'], bbox=bbox, vmin=0.0, vmax=1000,
               mask=mesh_france['global_active_cell'],
               ax_settings={"title": "Surfaces drainées", "xlabel": "Longitude",
                            "ylabel": "Latitude",
                            "clabel": "Surfaces drainées km^2"})

    """

    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Always work on a float32 copy: np.float32(arr) may hand back the caller's
    # own array when it is already float32, and the masking below writes NaN.
    matrice = np.array(matrice, dtype=np.float32, subok=True)

    extent = None
    if bbox is not None:
        extent = [bbox["left"], bbox["right"], bbox["bottom"], bbox["top"]]

    if mask is not None:
        matrice[np.where(mask == 0)] = np.nan

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    # NaN-aware bounds: masked pixels are NaN and would otherwise make
    # vmin/vmax NaN, which renders an empty image.
    if vmax is None:
        vmax = np.nanmax(matrice)
    if vmin is None:
        vmin = np.nanmin(matrice)

    fig, ax = ax_settings.change(figure=(fig, ax))

    im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # Use the target figure explicitly rather than pyplot's current figure
    fig.colorbar(im, label=ax_settings.clabel, cax=cax)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return (fig, ax)

Function for plotting a matrix as an image

Parameters

matrice : numpy array Matrix to be plotted bbox : dict {"left","right","bottom","top"} bounding box to put x and y coordinates instead of the shape of the matrix vmin: real, minimum z value vmax: real, maximum z value mask: integer, matrix, shape of matrice, contain 0 for pixels that should not be plotted catchment_polygon: dataframe containing some polygon to be plotted. Ideally it must contain the boundaries of the catchment as a polygon from a shp file read by geopandas. figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties

Examples

smash.utils.plot_image(mesh_france['drained_area'],bbox=bbox,title="Surfaces drainées",xlabel="Longitude",ylabel="Latitude",zlabel="Surfaces drainées km^2",vmin=0.0,vmax=1000,mask=mesh_france['global_active_cell'])

def plot_misfit( values: numpy.ndarray = [], names: numpy.ndarray = [], columns: list | None = None, misfit: str = 'nse', figure: list | tuple | None = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}):
def plot_misfit(
    values: np.ndarray = [],
    names: np.ndarray = [],
    columns: list | None = None,
    misfit: str = "nse",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the misfit criteria between the simulated and observed discharges
    as a bar chart (one grey bar per outlet, annotated with its value).

    Parameters:
    -----------
    values: np.ndarray
        The result of the discharge misfit for all outlets. Any 1-D
        array-like of numbers is accepted. Outlets whose value is NaN are
        removed from the plot.
    names: np.ndarray
        Outlets name or code stored in an np.ndarray (any 1-D sequence is
        accepted). When empty, outlets are labelled 0, 1, 2, ...
    columns: list | None
        Indices (or boolean mask) of the outlets to plot. None plots all.
    misfit: str
        Criteria to plot, used in the default y-axis label. Choices are
        ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
    figure: tuple
        Input figure as (fig, ax) to add a new curve. A new figure is
        created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    (fig, ax)
        The matplotlib figure and axes.

    Raises:
    -------
    ValueError
        If `names` is given and does not have the same size as `values`.
    """
    # Note: the list/dict defaults in the signature are only read, never
    # mutated, so sharing them across calls is harmless.
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{misfit} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=8,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Work on numpy arrays so that index lists and boolean masks also work
    # when plain Python lists are passed.
    values = np.asarray(values)
    names = np.asarray(names)

    if names.size == 0:
        names = np.arange(values.size)
    elif names.size != values.size:
        raise ValueError("names and values must have the same size !")

    if columns is not None:
        values = values[columns]
        names = names[columns]

    # Remove outlets without a valid misfit (NaN) from the plot.
    is_valid = ~np.isnan(values.astype(float))
    values = values[is_valid]
    names = names[is_valid]

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    bar_container = ax.bar(names, values, color="grey", tick_label=names)

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    # Settings are applied again after plotting so that they take precedence
    # over whatever ax.bar / ax.bar_label changed.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Plot the misfit criteria between the simulated and observed discharges.

Parameters:

values: np.ndarray The result of the discharge misfit for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray. columns: list | None Columns of the np.ndarray to plot misfit: str Criteria to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge']. figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties

def plot_outlet_stats( values_sim: numpy.ndarray | None = None, values_obs: numpy.ndarray | None = None, names: numpy.ndarray = [], columns: list | None = [], stat: str = 'max', figure: list | tuple | None = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}):
def plot_outlet_stats(
    values_sim: np.ndarray | None = None,
    values_obs: np.ndarray | None = None,
    names: np.ndarray = [],
    columns: list | None = [],
    stat: str = "max",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot a statistical criterion at a given list of outlets, as grouped bars
    (simulated next to observed) for each outlet.

    Parameters:
    -----------
    values_sim: np.ndarray or None
        The result of the simulated stat for all outlets. Not plotted if None.
    values_obs: np.ndarray or None
        The result of the observed stat for all outlets. Not plotted if None
        or if every value equals -99.0 (no observation).
    names: np.ndarray
        Outlets name or code stored in an np.ndarray (any 1-D sequence is
        accepted). When empty, outlets are labelled 0, 1, 2, ...
    columns: list | None
        Indices (or boolean mask) of the outlets to plot. None or an empty
        list (the default) plots all outlets.
    stat: str
        Criteria to plot, used in the default y-axis label. Choices are
        ['max', 'min', 'mean', 'median', 'q20', 'q80'].
    figure: tuple
        Input figure as (fig, ax) to add a new curve. A new figure is
        created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    (fig, ax)
        The matplotlib figure and axes.

    Raises:
    -------
    ValueError
        If values_sim and values_obs (or names) do not have the same size.
    """
    # Note: the list/dict defaults in the signature are only read, never
    # mutated, so sharing them across calls is harmless.
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{stat} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Work on numpy arrays so that index lists and boolean masks also work
    # when plain Python lists are passed.
    if values_sim is not None:
        values_sim = np.asarray(values_sim)
    if values_obs is not None:
        values_obs = np.asarray(values_obs)
    names = np.asarray(names)

    if names.size == 0:
        reference = values_sim if values_sim is not None else values_obs
        if reference is not None:
            names = np.arange(reference.size)

    # An empty selection (the default) means "all outlets"; indexing with []
    # would otherwise select nothing.
    if columns is not None and len(columns) > 0:
        if values_sim is not None:
            values_sim = values_sim[columns]
        if values_obs is not None:
            values_obs = values_obs[columns]
        names = names[columns]

    # -99.0 everywhere means that no observation is available.
    if values_obs is not None and np.all(values_obs == -99.0):
        values_obs = None

    if values_sim is not None and values_obs is not None:
        if values_obs.size != values_sim.size:
            raise ValueError("values_sim and values_obs must have the same size !")

    series = [
        (series_values, label)
        for series_values, label in ((values_sim, "sim"), (values_obs, "obs"))
        if series_values is not None
    ]

    for series_values, label in series:
        if series_values.size != names.size:
            raise ValueError(f"values_{label} and names must have the same size !")

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    x = np.arange(names.size)
    width = 0.25  # the width of each bar

    # One bar per series, shifted by one bar width for each extra series.
    for rank, (series_values, label) in enumerate(series):
        ax.bar(x + rank * width, series_values, width, label=label)

    # Bars are centre-aligned, so the middle of a group of n bars is at
    # x + width * (n - 1) / 2.
    group_centre = width * max(len(series) - 1, 0) / 2.0
    ax.set_xticks(x + group_centre, names)

    # Settings are applied again after plotting so that they take precedence
    # over whatever ax.bar / ax.set_xticks changed.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Plot a statistical criterion at a given list of outlets.

Parameters:

values_sim: np.ndarray or None The result of the simulated stat for all outlets. values_obs: np.ndarray or None The result of the observed stat for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray. columns: list | None Columns of the np.ndarray to plot stat: str Criteria to plot. Choices are ['max', 'min', 'mean', 'median', 'q20', 'q80']. figure: tuple input figure as (fig,ax) to add a new curve ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties

def plot_misfit_map( values: numpy.ndarray = [], names: numpy.ndarray = [], mesh=None, misfit: str = 'nse', coef_hydro=99.0, catchment_polygon: None | pandas.core.frame.DataFrame = None, ax_settings: dict | ax_properties = {}, fig_settings: dict | fig_properties = {}, plot_settings: dict | plot_properties = {}):
def plot_misfit_map(
    values: np.ndarray = [],
    names: np.ndarray = [],
    mesh=None,
    misfit: str = "nse",
    coef_hydro=99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Map plot of the misfit criteria between the simulated and observed discharges.

    The map is drawn on a new figure. It shows, from back to front:
    - the active cells in light grey;
    - the drainage network, coloured by flow accumulation;
    - the optional catchment polygons;
    - one marker and label per gauge, coloured by its misfit value.

    Parameters:
    -----------
    values: np.ndarray
        The result of the discharge misfit for all outlets, in the same order
        as mesh["code"] and mesh["gauge_pos"].
    names: np.ndarray
        Outlets name or code stored in an np.ndarray. Currently unused: the
        gauge labels are read from mesh["code"].
    mesh: dict
        The mesh of the Smash model as dict (mandatory).
    misfit: str
        Criteria to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
        It also sets the colour-scale bounds:
        - [0, 1] for 'nse' and 'nnse';
        - [0, max(values)] for 'rmse', 'nrmse' and 'se';
        - [min(values), max(values)] otherwise.
        NaN values are ignored when computing these bounds.
    coef_hydro: float
        Percentage used to select the drainage network. Cells whose flow
        accumulation is below (1 - coef_hydro / 100) * max(flow accumulation)
        are not drawn.
    catchment_polygon: None | DataFrame
        Polygons drawn with a black edge and no fill (e.g. catchment
        boundaries read from a shapefile by geopandas).
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.
    plot_settings: dict or class plot_properties
        Object or dict with any attribute of class plot_properties. Its
        `color` attribute is ignored: markers are coloured by misfit value.

    Returns:
    --------
    (fig, ax)
        The matplotlib figure and axes.

    Raises:
    -------
    ValueError
        If mesh is None or is not a dict.
    """
    if mesh is None:
        raise ValueError(
            "model or mesh are mandatory and must be a dict or a smash Model object"
        )
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title=f"Map of {misfit} criteria over the domain.",
            xlabel="x_coords",
            ylabel="y_coords",
            cmap="turbo_r",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="o",
            markersize=8,
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    # The marker colour is derived from the misfit value below; drop the
    # settings colour so that `color` is not passed twice to ax.plot.
    delattr(default_plot_settings, "color")

    values = np.asarray(values)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    flow_acc = mesh["flwacc"]
    na = mesh["active_cell"] == 0

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Keep only the cells whose flow accumulation is large enough
    # (drainage network); everything else is NaN and therefore not drawn.
    flow_accum_bv = np.where(na, 0.0, flow_acc.data)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    mask_flow = flow_accum_bv < surfmin
    flow_plot = np.where(mask_flow, np.nan, flow_accum_bv.data)
    flow_plot = np.where(na, np.nan, flow_plot)

    # NOTE: this resets the global matplotlib rcParams to their defaults.
    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    active_cell = np.where(na, np.nan, mesh["active_cell"])
    cmap = ListedColormap(["lightgray"])
    ax.imshow(active_cell, cmap=cmap, extent=extent)

    myblues = matplotlib.colormaps["binary"]
    cmp = ListedColormap(myblues(np.linspace(0.20, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # create an axes on the right side of ax. The width of cax will be 5%
    # of ax and the padding between cax and ax will be fixed at 0.05 inch.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    fig.colorbar(im, ax=ax, label="Cumulated surface (km²)", cax=cax)

    # Bounds of the misfit colour scale (NaN values are ignored).
    if misfit in ("nse", "nnse"):
        vmin, vmax = 0, 1
    elif misfit in ("rmse", "nrmse", "se"):
        vmin, vmax = 0, np.nanmax(values)
    else:
        vmin, vmax = np.nanmin(values), np.nanmax(values)

    # Map values to [0, 1] with a Normalize before looking up the colormap:
    # a colormap called with a float expects it to lie in [0, 1].
    # plt.get_cmap accepts a name, a Colormap instance or None
    # (cm.get_cmap was removed in matplotlib 3.9).
    colormap = plt.get_cmap(default_ax_settings.cmap)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)

    ha = "right"
    for i, station in enumerate(stations):
        value = values[i]
        # Alternate the label alignment from one gauge to the next to limit
        # overlapping labels.
        if ha == "right":
            ha = "left"
            str_val = str(np.round(value, 2)).rjust(len(station))
        else:
            ha = "right"
            str_val = str(np.round(value, 2)).ljust(len(station))
        code = f"{station}\n {str_val}"
        color = colormap(norm(value))

        coord = geo_toolbox.rowcol_to_xy(
            gauge[i][0],
            gauge[i][1],
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        )

        ax.plot(
            coord[0],
            coord[1],
            color=color,
            **default_plot_settings.__dict__,
        )

        ax.annotate(
            code,  # this is the text
            # these are the coordinates to position the label
            (coord[0], coord[1]),
            textcoords="data",  # xytext is given in data coordinates
            xytext=(coord[0], coord[1]),  # text placed on the gauge itself
            ha=ha,  # horizontal alignment can be left, right or center
            color=color,
            fontsize=default_ax_settings.annotate_fontsize
            * default_ax_settings.font_ratio,
        )

    # Second colorbar (misfit scale), appended to the right of the first one.
    cax = divider.append_axes("right", size="5%", pad=0.5)

    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=colormap),
        ax=ax,
        cax=cax,
        label=misfit,
        location="right",
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax

Map plot of the misfit criteria between the simulated and observed discharges.

Parameters:

values: np.ndarray The result of the discharge misfit for all outlets. names: np.ndarray Outlets name or code stored in an np.ndarray (not used for the labels, which are read from mesh["code"]). mesh: None | dict The mesh of the Smash model as dict misfit: str Criteria to plot. Choices are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge']. coef_hydro: float percentage selecting the drainage network: cells whose flow accumulation is below (1 - coef_hydro/100) times the maximum are not drawn catchment_polygon: None | DataFrame polygons drawn with a black edge (e.g. catchment boundaries) ax_settings: dict or class ax_properties object or dict with any attribute of class ax_properties fig_settings: dict or class fig_properties object or dict with any attribute of class fig_properties plot_settings: dict or class plot_properties object or dict with any attribute of class plot_properties.