
  1import os
  2import re
  3from string import ascii_lowercase
  4from typing import (
  6    Callable,
  7    List,
  8    Tuple,
  9    Union,
 10    Dict,
 11    Type,
 12    Iterable,
 13    Optional,
 16import gin
 17import matplotlib
 18import numpy as np
 19import seaborn as sns
 20from adjustText import adjust_text
 21from matplotlib import pyplot as plt
 22from tqdm import tqdm
 24from colosseum import config
 25from colosseum.analysis.tables import get_latex_table_of_average_indicator
 26from colosseum.analysis.utils import get_available_mdps_agents_prms_and_names
 27from colosseum.analysis.utils import get_formatted_name
 28from colosseum.analysis.utils import get_logs_data, add_time_exceed_sign_to_plot
 29from colosseum.experiment.agent_mdp_interaction import MDPLoop
 30from colosseum.experiment.folder_structuring import get_experiment_config
 31from colosseum.experiment.folder_structuring import get_mdp_agent_gin_configs
 32from colosseum.experiment.utils import apply_gin_config
 33from colosseum.hardness.analysis import compute_hardness_measure
 34from colosseum.utils import ensure_folder
 35from colosseum.utils.formatter import clear_agent_mdp_class_name
 38    from colosseum.mdp import BaseMDP
 39    from matplotlib.figure import Figure
 44def _get_index(x):
 45    return clear_agent_mdp_class_name(x[0].__name__), x[1]
 48def agent_performances_per_mdp_plot(
 49    experiment_folder: str,
 50    indicator: str,
 51    figsize_scale: int = 8,
 52    standard_error: bool = False,
 53    color_palette: List[str] = matplotlib.colors.TABLEAU_COLORS.keys(),
 54    n_rows=None,
 55    savefig_folder=None,
 56    baselines=MDPLoop.get_baselines(),
 57) -> "Figure":
 58    """
 59    produces a plot in which the performance indicator of the agents is shown for each MDP for the given experiment
 60    results.
 62    Parameters
 63    ----------
 64    experiment_folder : str
 65        The folder that contains the experiment logs, MDP configurations and agent configurations.
 66    indicator : str
 67        The code name of the performance indicator that will be shown in the plot. Check `MDPLoop.get_indicators()` to
 68        get a list of the available indicators.
 69    figsize_scale : int
 70        The scale for size of the figure in the resulting plot. The default value is 8.
 71    standard_error : bool
 72        If True standard errors are computed instead of the bootstrapping estimates in seaborn.
 73    color_palette : List[str]
 74        The colors to be assigned to the agents. By default, the tableau colors are used.
 75    n_rows : int
 76        The number of rows for the grid of plots. By default, it is computed to create a square grid.
 77    savefig_folder : str
 78        The folder where the figure will be saved. By default, the figure it is not saved.
 79    baselines : List[str]
 80        The baselines to be included in the plot. Check `MDPLoop.get_baselines()` to get a list of the available
 81        baselines. By default, all baselines are shown.
 83    Returns
 84    -------
 85    Figure
 86        The matplotlib figure.
 87    """
 89    # Check the inputs
 90    assert (
 91        indicator in MDPLoop.get_indicators()
 92    ), f"Please check that the indicator given in input is one from {MDPLoop.get_indicators()}."
 93    assert all(
 94        b in MDPLoop.get_baselines() for b in baselines
 95    ), f"Please check that the baselines given in input are available."
 97    # Retrieve the MDPs and agents configurations from the experiment folder
 98    available_mdps, available_agents = get_available_mdps_agents_prms_and_names(
 99        experiment_folder
100    )
102    # Variables for the plots
103    colors_dict_agents = dict(zip(available_agents, color_palette))
104    n_plots = len(available_mdps)
105    h = int(np.ceil(n_plots ** 0.5)) if n_rows is None else n_rows
106    w = int(np.ceil(n_plots / h))
107    fig, axes = plt.subplots(
108        h,
109        w,
110        figsize=(w * figsize_scale, h * figsize_scale),
111        sharex=True,
112        # If the indicator is normalized we can also share the indicator axis
113        sharey="normaliz" in indicator,
114    )
115    if config.VERBOSE_LEVEL != 0:
116        available_mdps = tqdm(
117            sorted(available_mdps, key=lambda x: "".join(x)),
118            desc="Plotting the results",
119        )
120    else:
121        available_mdps = sorted(available_mdps, key=lambda x: "".join(x))
123    for i, available_mdp in enumerate(available_mdps):
124        ax = axes.ravel()[i]
125        mdp_formatted_name = get_formatted_name(*available_mdp)
126        group_by_mdp_individual_plot(
127            experiment_folder,
128            ax,
129            indicator,
130            *available_mdp,
131            available_agents,
132            colors_dict_agents,
133            standard_error=standard_error,
134            baselines=baselines,
135        )
136        ax.set_title(mdp_formatted_name)
137        ax.legend()
138        ax.ticklabel_format(style="sci", scilimits=(0, 4))
140    # Remove unused axes
141    for j in range(i + 1, len(axes.ravel())):
142        fig.delaxes(axes.ravel()[j])
144    # Last touches
145    plt.ticklabel_format(style="sci", scilimits=(0, 4))
146    plt.tight_layout()
148    if savefig_folder is not None:
149        os.makedirs(savefig_folder, exist_ok=True)
150        exp_name = os.path.basename(os.path.dirname(ensure_folder(experiment_folder)))
151        plt.savefig(
152            f"{ensure_folder(savefig_folder)}{indicator}-for-{exp_name}.pdf",
153            bbox_inches="tight",
154        )
158    return fig
def get_hardness_measures_from_experiment_folder(
    experiment_folder: str,
    hardness_measures: Iterable[str] = ("diameter", "value_norm", "suboptimal_gaps"),
    reduce_seed: Callable[[List[float]], float] = np.mean,
) -> Dict[Tuple[Type["BaseMDP"], str], Dict[str, float]]:
    """
    retrieves the given measures of hardness for each mdp and mdp gin config in the experiment folder.

    Parameters
    ----------
    experiment_folder : str
        The folder that contains the experiment logs, MDP configurations and agent configurations.
    hardness_measures : Iterable[str]
        The measures of hardness to compute. Any iterable is accepted, including one-shot iterators.
    reduce_seed : Callable[[List[float]], float], optional
        The function that reduces the values of the measures for different seed to a single scalar. By default, the
        mean function is employed.

    Returns
    -------
    Dict[Tuple[Type["BaseMDP"], str], Dict[str, float]]
        The dictionary that assigns to each MDP class and gin config index the corresponding dictionary containing the
        hardness measures names and values.
    """

    # The measures are iterated once per MDP scope, so a one-shot iterable must be materialised first; otherwise
    # every scope after the first would get an empty dictionary.
    hardness_measures = tuple(hardness_measures)

    # Retrieve the gin configurations of the agents and MDPs
    (
        mdp_classes_scopes,
        agent_classes_scopes,
        gin_config_files_paths,
    ) = get_mdp_agent_gin_configs(experiment_folder)

    # Retrieve the number of seeds
    n_seeds = get_experiment_config(experiment_folder).n_seeds

    # normpath drops a trailing separator, which would otherwise make basename return an empty description
    progress_desc = os.path.basename(os.path.normpath(experiment_folder))

    res = dict()
    for mdp_class, mdp_scopes in tqdm(mdp_classes_scopes.items(), desc=progress_desc):
        for mdp_scope in mdp_scopes:
            # The gin configuration is (re-)applied before entering each scope, as in the original flow
            apply_gin_config(gin_config_files_paths)
            with gin.config_scope(mdp_scope):
                measures = dict()
                for hm in hardness_measures:
                    values_per_seed = [
                        compute_hardness_measure(mdp_class, dict(seed=seed), hm)
                        for seed in range(n_seeds)
                    ]
                    measures[hm] = reduce_seed(values_per_seed)
                res[mdp_class, mdp_scope] = measures
    return res
def _hardness_measure_label(measure) -> str:
    """returns the axis label for a hardness measure given as a code name or as a (name, callable) pair."""
    # The annotation of the axis measures allows a (name, callable) tuple; in that case the name is used.
    name = measure if isinstance(measure, str) else measure[0]
    return name.capitalize().replace("_", " ")


def plot_labels_on_benchmarks_hardness_space(
    experiment_folder: str,
    text_f: Callable[[Tuple[Type["BaseMDP"], str]], str],
    color_f: Callable[[Tuple[Type["BaseMDP"], str]], Union[str, None]] = lambda x: None,
    label_f: Callable[[Tuple[Type["BaseMDP"], str]], Union[str, None]] = lambda x: None,
    ax: plt.Axes = None,
    multiplicative_factor_xlim=1.0,
    multiplicative_factor_ylim=1.0,
    legend_ncol=1,
    underneath_x_label: str = None,
    set_ylabel=True,
    set_legend=True,
    xaxis_measure: Union[str, Tuple[str, Callable[["BaseMDP"], float]]] = "diameter",
    yaxis_measure: Union[str, Tuple[str, Callable[["BaseMDP"], float]]] = "value_norm",
    fontsize: int = 22,
    fontsize_xlabel_underneath: int = 32,
    text_label_fontsize=16,
):
    """
    for each MDP configuration in the experiment folder, it places a text label and a point in the position
    corresponding to the values of the x-axis measure and y-axis measure of that MDP. In addition to the text, it is
    also possible to choose the color assigned to the point in such position.

    The callbacks `text_f`, `color_f` and `label_f` receive the keys of the dictionary returned by
    `get_hardness_measures_from_experiment_folder`, i.e. (MDP class, gin config index) pairs.

    Parameters
    ----------
    experiment_folder : str
        The folder that contains the experiment logs, MDP configurations and agent configurations.
    text_f : Callable[[Tuple[Type["BaseMDP"], str]], str]
        The function that returns a text label for a given MDP class and gin config index. For example,
        (DeepSeaEpisodic, 'prms_0') -> "DeepSeaEpisodic (0)".
    color_f : Callable[[Tuple[Type["BaseMDP"], str]], str]
        The function that returns the color for the point in the position corresponding to a given MDP class and
        gin config index. By default, no particular color is specified.
    label_f : Callable[[Tuple[Type["BaseMDP"], str]], str]
        The function that returns the label to be put in the legend for the point in the position corresponding to a
        given MDP class and gin config index. For example, (DeepSeaEpisodic, 'prms_0') -> "DeepSea family".
        By default, no labels are produced and the legend is therefore not drawn.
    ax : plt.Axes
        The ax object where the plot will be put. By default, a new axis is created.
    multiplicative_factor_xlim : float
        The factor by which the upper x limit is multiplied. It can be useful to add space for the legend. By
        default, it is set to one.
    multiplicative_factor_ylim : float
        The factor by which the upper y limit is multiplied. It can be useful to add space for the legend. By
        default, it is set to one.
    legend_ncol : int
        The number of columns in the legend. By default, it is set to one.
    underneath_x_label : str
        Text to be added underneath the x_label. By default, no text is added.
    set_ylabel : bool
        If True, the y-label is set to the name of the y-axis measure. By default, the y-label is set. It is always
        set when a new axis is created.
    set_legend : bool
        If True, the legend is set (when at least one point has a label). By default, the legend is set.
    xaxis_measure : str
        The code name of the hardness measures available in the package. Check `BaseMDP.get_available_hardness_measures()`
        to get to know the available ones. By default, it is set to the diameter.
    yaxis_measure : str
        The code name of the hardness measures available in the package. Check `BaseMDP.get_available_hardness_measures()`
        to get to know the available ones. By default, it is set to the value norm.
    fontsize : int
        The font size for x and y labels. By default, it is set to :math:`22`.
    fontsize_xlabel_underneath : int
        The font size for the text below the x label. By default, it is set to :math:`32`.
    text_label_fontsize : int
        The font size for the text labels of the points. By default, it is set to :math:`16`.
    """

    show = ax is None
    if show:
        _, ax = plt.subplots(1, 1, figsize=(8, 8))
        set_ylabel = True

    hardness_measures = get_hardness_measures_from_experiment_folder(
        experiment_folder, (xaxis_measure, yaxis_measure)
    )
    texts = []
    for k, r in hardness_measures.items():
        x_value, y_value = r[xaxis_measure], r[yaxis_measure]
        texts.append(
            ax.text(
                x_value,
                y_value,
                text_f(k),
                fontdict=dict(fontsize=text_label_fontsize),
            )
        )
        ax.scatter(
            x_value,
            y_value,
            500,
            color=color_f(k),
            label=label_f(k),
            edgecolor="black",
            linewidths=0.5,
        )

    ax.tick_params(labelsize=22)
    if set_ylabel:
        ax.set_ylabel(
            _hardness_measure_label(yaxis_measure),
            fontdict=dict(fontsize=fontsize),
            labelpad=10,
        )
    ax.set_xlabel(
        _hardness_measure_label(xaxis_measure),
        fontdict=dict(fontsize=fontsize),
        labelpad=15,
        ha="center",
    )

    # Enlarge the upper limits to leave room for the legend
    xlim = ax.get_xlim()
    ax.set_xlim(xlim[0], xlim[1] * multiplicative_factor_xlim)
    ylim = ax.get_ylim()
    ax.set_ylim(ylim[0], ylim[1] * multiplicative_factor_ylim)

    if isinstance(underneath_x_label, str):
        # The vertical offset is computed from the y limits before the enlargement above
        ax.text(
            np.mean(ax.get_xlim()),
            ylim[0] - 0.28 * (ylim[1] - ylim[0]),
            underneath_x_label,
            fontdict=dict(fontsize=fontsize_xlabel_underneath),
            ha="center",
        )

    if set_legend:
        handles, _ = ax.get_legend_handles_labels()
        # Only draw the legend when label_f produced at least one label
        if handles:
            ax.legend(ncol=legend_ncol)

    plt.tight_layout()
    adjust_text(
        texts,
        ax=ax,
        expand_text=(1.05, 1.8),
        expand_points=(1.05, 1.5),
        # Valid axis specifications are combinations of "x" and "y"
        only_move={"points": "y", "text": "xy"},
        precision=0.0001,
        lim=1000,
    )

    if show:
        plt.tight_layout()
def plot_indicator_in_hardness_space(
    experiment_folder: str,
    indicator: str = "normalized_cumulative_regret",
    fontsize: int = 22,
    cmap: str = "Reds",
    fig_size=8,
    text_label_fontsize=14,
    savefig_folder: Optional[str] = "tmp",
) -> "Figure":
    """
    for each agent config, it produces a plot that places the given indicator obtained by the agent config for each MDP
    config in the position corresponding to the diameter and value norm of the MDP.

    Parameters
    ----------
    experiment_folder : str
        The path of the directory containing the experiment results.
    indicator : str
        is a string representing the performance indicator that is shown in the plot. Check `MDPLoop.get_indicators()`
        to get a list of the available indicators. By default, the 'normalized_cumulative_regret' is used.
    fontsize : int
        The font size for x and y labels. By default, it is set to :math:`22`.
    cmap : str
        The code name for the color map to be used when plotting the indicator values. By default,
        the 'Reds' color map is used.
    fig_size : int
        The size of the figures in the grid of plots. By default, it is set to :math:`8`.
    text_label_fontsize : int
        The font size for the text labels of the points. By default, it is set to :math:`14`.
    savefig_folder : str
        The folder where the figure will be saved. By default, the figure is saved in a local folder with name 'tmp'.
        If the directory does not exist, it is created. If None, the figure is not saved.

    Returns
    -------
    Figure
        The matplotlib figure.
    """

    # Colormap object that maps a value in [0, 1] to a color
    color_map = plt.get_cmap(cmap)
    _, df = get_latex_table_of_average_indicator(
        experiment_folder,
        indicator,
        show_prm=True,
        return_table=True,
        mdps_on_row=False,
    )
    # Each cell of the table is a string; the first decimal number found in it is used as its numerical value
    decimal_pattern = re.compile(r"\d+\.\d+")
    df_numerical = df.applymap(lambda s: float(decimal_pattern.findall(s)[0]))

    n_agents = len(df.index)
    fig, axes = plt.subplots(
        1, n_agents, figsize=(n_agents * fig_size + 1, fig_size), sharey=True
    )
    if n_agents == 1:
        # With a single subplot, plt.subplots returns the Axes itself rather than an array
        axes = np.array([axes])
    for i, (a, ax) in enumerate(zip(df.index, axes.tolist())):
        # The lambdas below are only called inside this iteration, so closing over `a` is safe
        plot_labels_on_benchmarks_hardness_space(
            experiment_folder,
            label_f=lambda x: None,
            # Color intensity is the indicator value relative to the largest value obtained by this agent
            color_f=lambda x: color_map(
                df_numerical.loc[a, _get_index(x)] / df_numerical.loc[a].max()
            ),
            text_f=lambda x: f"{_get_index(x)[0].replace('MiniGrid', 'MG-')} "
            f"({(_get_index(x)[1].replace('prms_', ''))})",
            ax=ax,
            fontsize=fontsize,
            text_label_fontsize=text_label_fontsize,
            underneath_x_label=f"({ascii_lowercase[i]}) {a[0]}",
        )

    plt.tight_layout()

    if savefig_folder is not None:
        os.makedirs(savefig_folder, exist_ok=True)
        exp_name = os.path.basename(os.path.dirname(ensure_folder(experiment_folder)))
        plt.savefig(
            f"{ensure_folder(savefig_folder)}{indicator}_in_hard_space_{exp_name}.pdf",
            bbox_inches="tight",
        )

    return fig
def group_by_mdp_individual_plot(
    experiment_folder: str,
    ax,
    measure: str,
    mdp_class_name: str,
    mdp_prms: str,
    available_agents: List[Tuple[str, str]],
    colors_dict_agents: Dict[Tuple[str, str], str],
    standard_error: bool = False,
    baselines=MDPLoop.get_baselines(),
):
    """
    plots the measure for the given agents and experiment folder in the given axes.

    Parameters
    ----------
    experiment_folder : str
        is the folder that contains the experiment logs, MDP configurations and agent configurations.
    ax : plt.Axes
        is where the plot will be shown.
    measure : str
        is a string representing the performance measure that is shown in the plot. Check
        MDPLoop.get_indicators() to get a list of the available indicators.
    mdp_class_name : str
        is a string that contains the mdp class name.
    mdp_prms : str
        is a string that contains the mdp parameter gin config parameter, i.e. 'prms_0'.
    available_agents : List[Tuple[str, str]]
        is a list containing, for each agent, a pair whose elements are forwarded positionally to
        `get_formatted_name` and `get_logs_data` (presumably the agent class name followed by the agent gin config
        parameter, mirroring the MDP arguments).
    colors_dict_agents : Dict[Tuple[str, str], str]
        is a dict that assigns a color to each element of `available_agents`.
    standard_error : bool
        If True the standard error is shown around the agent curves, otherwise a 95% confidence interval is shown.
    baselines : List[str]
        The baselines to be included in the plot. They are drawn only once, using the logs of the first agent.
    """
    mdp_code = mdp_prms + config.EXPERIMENT_SEPARATOR_PRMS + mdp_class_name

    # The baseline columns are named after the expected reward rather than the realised one
    baseline_measure = measure.replace(
        "cumulative_reward", "cumulative_expected_reward"
    )
    agent_errorbar = "se" if standard_error else ("ci", 95)
    y_label = " ".join(word.capitalize() for word in measure.split("_"))

    baselines_drawn = False
    for available_agent in available_agents:
        agent_code = (
            available_agent[1] + config.EXPERIMENT_SEPARATOR_PRMS + available_agent[0]
        )
        agent_formatted_name = get_formatted_name(*available_agent)
        df, n_seeds = get_logs_data(
            experiment_folder, mdp_class_name, mdp_prms, *available_agent
        )

        # We plot the baselines only once
        if not baselines_drawn:
            for b in baselines:
                baseline_column = b + "_" + baseline_measure
                if baseline_column in MDPLoop.get_baseline_indicators():
                    sns.lineplot(
                        x="steps",
                        y=baseline_column,
                        label=b.capitalize() + " agent",
                        data=df,
                        ax=ax,
                        errorbar=None,
                        color=MDPLoop.get_baselines_color_dict()[b],
                        linestyle=MDPLoop.get_baselines_style_dict()[b],
                        linewidth=2,
                    )
            baselines_drawn = True

        agent_color = colors_dict_agents[available_agent]
        add_time_exceed_sign_to_plot(
            ax,
            df,
            agent_color,
            measure,
            n_seeds,
            experiment_folder,
            mdp_code,
            agent_code,
        )
        sns_ax = sns.lineplot(
            x="steps",
            y=measure,
            label=agent_formatted_name,
            data=df,
            ax=ax,
            errorbar=agent_errorbar,
            color=agent_color,
        )
        sns_ax.set_ylabel(y_label)
