
  1import re
  2from typing import List, Tuple, Union
  4import numpy as np
  5import pandas as pd
  7from colosseum.analysis.utils import format_indicator_name, get_n_failed_interactions
  8from colosseum.analysis.utils import get_available_mdps_agents_prms_and_names
  9from colosseum.analysis.utils import get_formatted_name, get_logs_data
 10from colosseum.experiment.agent_mdp_interaction import MDPLoop
 11from colosseum.utils.formatter import clear_agent_mdp_class_name
 14def get_latex_table_of_average_indicator(
 15    experiment_folder: str,
 16    indicator: str,
 17    show_prm: bool = False,
 18    divide_by_total_number_of_time_steps: bool = True,
 19    mdps_on_row: bool = True,
 20    print_table: bool = False,
 21    return_table: bool = False,
 22) -> Union[str, Tuple[str, pd.DataFrame]]:
 23    r"""
 24    produces a latex table whose entries are the averages over the seeds of the indicator given in input for the
 25    results in the experiment folder.
 27    Parameters
 28    ----------
 29    experiment_folder : str
 30        The folder that contains the experiment logs, MDP configurations and agent configurations.
 31    indicator : str
 32        The code name of the performance indicator that will be shown in the plot. Check `MDPLoop.get_indicators()` to
 33        get a list of the available indicators.
 34    show_prm : bool
 35        If True, the gin parameter config index is shown next to the agent/MDP class name. By default, it is not shown.
 36    divide_by_total_number_of_time_steps : bool
 37        If True, the value of the indicator is divided by the total number of time steps of agent/MDP interaction. By
 38        default, it is divided.
 39    mdps_on_row : bool
 40        If True, MDPs are shown in the rows. If False, agents are shown on the row indices. By default, MDPs are shown
 41        on row indices.
 42    print_table : bool
 43        If True, the table is printed.
 44    return_table : bool
 45        If True, in addition to the string with the :math:`\LaTeX` table, the pd.DataFrame is also returned. By default,
 46        only the string with the :math:`\LaTeX` table is returned.
 48    Returns
 49    -------
 50    Union[str, Tuple[str, pd.DataFrame]]
 51        The :math:`\LaTeX` table, and optionally the pd.DataFrame associated.
 52    """
 54    available_mdps, available_agents = get_available_mdps_agents_prms_and_names(
 55        experiment_folder
 56    )
 58    table = pd.DataFrame(
 59        columns=pd.MultiIndex.from_tuples([("MDP", "")] + available_agents), dtype=str
 60    )
 61    agent_average_performance = {a: [] for a in available_agents}
 62    for i, (mdp_class_name, mdp_prm) in enumerate(available_mdps):
 63        row = [mdp_class_name]
 64        for k, (agent_class_name, agent_prm) in enumerate(available_agents):
 65            df, n_seeds = get_logs_data(
 66                experiment_folder,
 67                mdp_class_name,
 68                mdp_prm,
 69                agent_class_name,
 70                agent_prm,
 71            )
 72            values = df.loc[df.steps == df.steps.max(), indicator]
 73            if divide_by_total_number_of_time_steps:
 74                values /= df.steps.max() + 1
 75            row.append(f"${values.mean():.2f}\\pm{values.std():4.2f}$")
 76            agent_average_performance[agent_class_name, agent_prm].append(values.mean())
 77        if show_prm:
 78            row[0] = get_formatted_name(mdp_class_name, mdp_prm)
 80        scores = [float(re.findall(r"\$[0-9].[0-9]+", r)[0][1:]) for r in row[1:]]
 81        if "regret" in indicator or "steps_per_second" in indicator:
 82            best_scores = "$" + f"{(min(scores)):.2f}"
 83        elif "reward" in indicator:
 84            best_scores = "$" + f"{(max(scores)):.2f}"
 85        else:
 86            raise ValueError(f"I'm not sure whether min or max is best for {indicator}")
 87        for k in range(1, len(row)):
 88            row[k] = row[k].replace(best_scores, "$\\mathbf{" + best_scores[1:] + "}")
 90        row[0] = clear_agent_mdp_class_name(row[0])
 91        table.loc[len(table)] = row
 93    row = [r"\textit{Average}"]
 94    for c in table.columns[1:]:
 95        values = np.array(agent_average_performance[c])
 96        row.append(f"${values.mean():.2f}\\pm{values.std():4.2f}$")
 98    scores = [float(re.findall(r"\$[0-9].[0-9]+", r)[0][1:]) for r in row[1:]]
 99    if "regret" in indicator or "steps_per_second" in indicator:
100        best_scores = "$" + f"{(min(scores)):.2f}"
101    elif "reward" in indicator:
102        best_scores = "$" + f"{(max(scores)):.2f}"
103    else:
104        raise ValueError(f"I'm not sure whether min or max is best for {indicator}")
105    for k in range(1, len(row)):
106        row[k] = row[k].replace(best_scores, "$\\mathbf{" + best_scores[1:] + "}")
107    table.loc[len(table)] = row
109    table.columns = pd.MultiIndex.from_tuples(
110        [(clear_agent_mdp_class_name(n), p) for n, p in table.columns.values]
111    )
112    table = table.set_index("MDP")
113    table_lat = table.copy()
114    if show_prm:
115        table_lat.index = [
116            c.replace(c.split(" ")[0], " " * len(c.split(" ")[0]))
117            if i > 0 and c.split(" ")[0] == table_lat.index[i - 1].split(" ")[0]
118            else c
119            for i, c in enumerate(table_lat.index)
120        ]
121    else:
122        table_lat.index = [
123            "" if i > 0 and c == table_lat.index[i - 1] else c
124            for i, c in enumerate(table_lat.index)
125        ]
127    if not mdps_on_row:
128        table = table.T
129        table.columns = pd.MultiIndex.from_tuples(
130            [(clear_agent_mdp_class_name(n), p) for n, p in available_mdps]
131            + [(r"\textit{Average}", "")]
132        )
133        table_lat = table.copy()
134        table_lat.index = [
135            "" if i > 0 and c == table_lat.index[i - 1][0] else c
136            for i, (c, p) in enumerate(table_lat.index.values)
137        ]
138 = None
140    if print_table:
141        with pd.option_context(
142            "display.max_rows", None, "display.max_columns", None, "display.width", 500
143        ):
144            print(table)
146    table_lat = table_lat.to_latex(escape=False).replace(
147        r"\bottomrule", r"\arrayrulecolor{black!15}\midrule%"
148    )
149    if not show_prm:
150        table_lat = table_lat.split("\n")
151        table_lat.pop(3)
152        table_lat = "\n".join(table_lat)
154    # Add midrules between agents/mdps parameters
155    table_lat = table_lat.split("\n")
156    row_indices = (available_mdps if mdps_on_row else available_agents) + [
157        (r"\textit{Average}", "")
158    ]
159    for i, (c, p) in reversed(list(enumerate(row_indices))):
160        if i > 0 and c != row_indices[i - 1][0]:
161            table_lat.insert(
162                i + 4,
163                r"\arrayrulecolor{black!"
164                + f"{30 if 'Average' in row_indices[i][0] else 15}"
165                + "}\midrule%",
166            )
167    table_lat = "\n".join(table_lat).replace("{l}", "{c}").replace("MiniGrid", "MG-")
169    # Centering the columns with numbers
170    columns_labels = re.findall(r"\{l+\}", table_lat)[0]
171    table_lat = table_lat.replace(
172        "l" * (len(columns_labels) - 2), "l" + "c" * (len(columns_labels) - 3)
173    )
175    if return_table:
176        return table_lat, table
177    return table
def get_latex_table_of_indicators(
    experiment_folder: str,
    indicators: List[str],
    show_prm_agent: bool = False,
    divide_by_total_number_of_time_steps: bool = True,
    print_table: bool = False,
    show_prm_mdp: bool = True,
) -> str:
    r"""
    Produces a :math:`\LaTeX` table with one row per MDP/agent pair. Its entries are the mean :math:`\pm` standard
    deviation over the seeds of each given indicator, taken at the last logged time step, followed by the number of
    completed seeds.

    Parameters
    ----------
    experiment_folder : str
        The folder that contains the experiment logs, MDP configurations and agent configurations.
    indicators : List[str]
        The list of strings containing the performance indicators that will be shown in the table. Check
        `MDPLoop.get_indicators()` to get a list of the available indicators.
    show_prm_agent : bool
        If True, the gin parameter config index is shown next to the agent class name. By default, it is not shown.
    divide_by_total_number_of_time_steps : bool
        If True, the value of the indicator is divided by the total number of time steps of agent/MDP interaction. By
        default, it is divided.
    print_table : bool
        If True, the table is printed.
    show_prm_mdp : bool
        If True, the gin parameter config index is shown next to the MDP class name. By default, it is shown.

    Returns
    -------
    str
        The :math:`\LaTeX` table.

    Raises
    ------
    AssertionError
        If one of ``indicators`` is not in `MDPLoop.get_indicators()`.
    """
    available_indicators = MDPLoop.get_indicators()
    assert all(
        ind in available_indicators for ind in indicators
    ), f"I received an invalid indicator, the available indicators are: {available_indicators}"

    available_mdps, available_agents = get_available_mdps_agents_prms_and_names(
        experiment_folder
    )

    table = pd.DataFrame(
        columns=[
            "MDP",
            "Agent",
            *map(format_indicator_name, indicators),
            r"\# completed seeds",
        ],
        dtype=str,
    )
    for mdp_class_name, mdp_prm in available_mdps:
        for agent_class_name, agent_prm in available_agents:
            df, n_seeds = get_logs_data(
                experiment_folder,
                mdp_class_name,
                mdp_prm,
                agent_class_name,
                agent_prm,
            )

            if "Continuous" in agent_class_name:
                # Re-normalise against the worst/optimal cumulative expected rewards, rescaled by the last step.
                # Item assignment (not attribute assignment) so pandas always writes a column.
                horizon = df.steps.max()
                reward_range = (
                    df.optimal_cumulative_expected_reward
                    - df.worst_cumulative_expected_reward
                )
                df["normalized_cumulative_expected_reward"] = (
                    horizon
                    * (
                        df.cumulative_expected_reward
                        - df.worst_cumulative_expected_reward
                    )
                    / reward_range
                )
                df["normalized_cumulative_reward"] = (
                    horizon
                    * (df.cumulative_reward - df.worst_cumulative_expected_reward)
                    / reward_range
                )

            # Snap values that are numerically indistinguishable from zero to exactly zero.
            df[np.isclose(df, 0)] = 0

            last_step = df.steps.max()
            values = df.loc[df.steps == last_step, indicators]
            if divide_by_total_number_of_time_steps:
                values = values / (last_step + 1)

            n_failed = get_n_failed_interactions(
                experiment_folder,
                mdp_class_name,
                mdp_prm,
                agent_class_name,
                agent_prm,
            )

            mdp_name = mdp_class_name
            if show_prm_mdp:
                mdp_name = get_formatted_name(mdp_class_name, mdp_prm).replace(
                    "MiniGrid", "MG-"
                )
            agent_name = (
                get_formatted_name(agent_class_name, agent_prm)
                if show_prm_agent
                else agent_class_name
            )
            table.loc[len(table)] = [
                clear_agent_mdp_class_name(mdp_name),
                clear_agent_mdp_class_name(agent_name),
                *(f"${v.mean():.2f}\\pm{v.std():.2f}$" for v in values.values.T),
                f"${n_seeds - n_failed}/{n_seeds}$",
            ]

    # Hide repeated labels so each MDP, and each agent within an MDP, is named once. Comparisons use the labels
    # as they were before blanking.
    mdp_labels = list(table["MDP"])
    agent_labels = list(table["Agent"])
    table["MDP"] = [
        "" if i > 0 and label == mdp_labels[i - 1] else label
        for i, label in enumerate(mdp_labels)
    ]
    # An agent label is hidden only inside the same MDP group, so the first row of every MDP names its agent.
    table["Agent"] = [
        ""
        if i > 0
        and label == agent_labels[i - 1]
        and mdp_labels[i] == mdp_labels[i - 1]
        else label
        for i, label in enumerate(agent_labels)
    ]
    table = table.set_index(["MDP", "Agent"])

    if print_table:
        with pd.option_context(
            "display.max_rows", None, "display.max_columns", None, "display.width", 500
        ):
            print(table)

    # Data row i is expected on LaTeX line i + 5 (presumably \begin{tabular}, \toprule, the column-name row, the
    # index-name row and \midrule come first). A rule spans all columns when a new MDP starts, and starts at the
    # agent column when only the agent changes.
    short_rule_lines = set()
    long_rule_lines = set()
    for i, (mdp_i, agent_i) in enumerate(table.index):
        if i > 0 and agent_i != "" and agent_i != table.index[i - 1][1]:
            (short_rule_lines if mdp_i == "" else long_rule_lines).add(i + 5)

    # The LaTeX table has the two index columns (MDP, Agent) plus one column per DataFrame column.
    n_latex_columns = 2 + len(table.columns)
    latex_lines = table.to_latex(escape=False).split("\n")
    # Walk backwards so an insertion never shifts the line numbers still to be processed.
    for line_idx in reversed(range(len(latex_lines))):
        if line_idx in long_rule_lines:
            first_column = 1
        elif line_idx in short_rule_lines:
            first_column = 2
        else:
            continue
        latex_lines.insert(
            line_idx,
            r"\arrayrulecolor{black!15}\cmidrule{"
            + f"{first_column}-{n_latex_columns}"
            + "}",
        )
    return "\n".join(latex_lines)
