Skip to content

BreakDownPlotter

BreakDownPlotter

Source code in src/psPlotKit/data_plotter/ps_break_down_plotter.py
class BreakDownPlotter:
    def __init__(
        self,
        PsData,
        save_location=None,
        save_folder=None,
        save_name=None,
        show_fig=True,
    ):

        self.save_location = create_save_location(save_location, save_folder)
        self.show_fig = show_fig
        self.select_data_key_list = []
        self.PsData = PsData
        self.define_plot_styles()
        self.line_indexes = {}
        self.line_groups = None
        self.xunit = None
        self.yunit = None
        self.xdata_label = None
        self.ydata_label = None
        self.save_name = save_name
        self.hatch_groups = {}
        self.area_groups = {}

    def _select_data(self, xkeys, ykeys):
        self.PsData.select_data(xkeys, require_all_in_dir=False)
        self.PsData.select_data(ykeys, add_to_existing=True, require_all_in_dir=False)

    def define_hatch_groups(self, groups=None):
        self.hatch_groups = groups

    def define_area_groups(self, groups):
        self.area_groups = groups

    def _get_color(self, label):
        if isinstance(self.line_colors, dict):
            return self.line_colors[label]
        else:
            idx = None
            for key in label:
                if key in self.line_indexes:
                    idx = self.line_indexes[key]["idx"]
                    break
            if idx == None:
                # _logger.info("Did not find line color group")

                auto_label = "auto_{}".format(label)

                if auto_label in self.line_indexes:
                    self.line_indexes[auto_label]["idx"] += 1
                    idx = self.line_indexes[auto_label]["idx"]

                else:
                    idx = 0
                    self.line_indexes[auto_label] = {"idx": 0}
            if isinstance(idx, int):
                color = self.line_colors[(idx % len(self.line_colors))]
            else:
                color = idx
            return color

    def define_plot_styles(self, hatch_options=None, line_color_list=None):
        if hatch_options is None:
            self.hatch_options = ["", "////", "///\\\\"]
        if line_color_list is None:
            self.line_colors = [
                "#a6cee3",
                "#1f78b4",
                "#b2df8a",
                "#33a02c",
                "#fb9a99",
                "#e31a1c",
                "#fdbf6f",
                "#ff7f00",
                "#cab2d6",
                "#6a3d9a",
                "#ffff99",
            ]
        else:
            self.line_colors = line_color_list

    def _get_data(self, data, key):
        data_keys = []
        data_list = []
        for dkey, data in data.items():
            if key in dkey:

                if isinstance(dkey, tuple):
                    skeys = []
                    for k in dkey:
                        if k != key:
                            skeys.append(k)
                else:
                    skeys = [dkey]
                data_keys.append(tuple(skeys))
                data_list.append(data)
        return data_keys, data_list

    def _get_axis_label(self, label, units):
        return "{} ({})".format(label, units)

    def _get_ydata(self, selected_keys, ydata):
        ykey_data = []
        for skey in selected_keys:
            if isinstance(ydata, list) or isinstance(ydata, tuple):
                all_test = all(ykey in str(skey) for ykey in ydata)
            else:
                all_test = ydata in str(skey)
            if all_test:
                ykey_data.append(skey)
        return ykey_data

    def _replace_key(self, skey, xdata):
        dir_key = list(skey)
        dir_key[-1] = xdata
        return tuple(dir_key)

    def check_key_in_dir(self, udir, key):
        for d in udir:
            if isinstance(d, str):
                if d == key:
                    return True
            else:
                for di in d:
                    if key == di:
                        return True
        return False

    def _get_group_options(self, selected_keys, xdata, ydata):
        self.plot_areas = {}
        # self.area_groups = {}
        for skey in self._get_ydata(selected_keys, ydata):
            opts = None
            key = None
            color = None
            for i, key in enumerate(self.hatch_groups):
                if key in str(skey):
                    opts = self.hatch_groups[key]
                    if opts.get("label") == None:
                        opts["label"] = key
                    if opts.get("color") == None:
                        opts["color"] = "white"
                    if opts.get("hatch") == None:
                        opts["hatch"] = self.hatch_options[i]

                    self.hatch_groups[key] = copy.deepcopy(opts)
                    # if key not in self.line_indexes:
                    #     self.line_indexes[key] = {"idx": 0, "auto": True}
            for akey in self.area_groups:
                _label = None
                if isinstance(akey, dict):
                    akey, item = list(akey.items())[0]
                    if isinstance(item, dict):
                        if "label" in item:
                            _label = item["label"]
                        if "color" in item:
                            color = item["color"]
                    else:
                        _label = item
                if self.check_key_in_dir(skey, akey):
                    if _label is None:
                        _label = akey
                    plot_label = _label
                    cur_line = {}
                    raw_ydata = self.selected_data[skey]
                    raw_xdata = self.selected_data[self._replace_key(skey, xdata)]
                    cur_line["ydata"] = raw_ydata.data
                    cur_line["xdata"] = raw_xdata.data

                    if self.xunit == None:
                        self.xunit = raw_xdata.mpl_units
                    if self.xdata_label == None:
                        self.xdata_label = raw_xdata.data_label
                    if self.yunit == None:
                        self.yunit = raw_ydata.mpl_units
                    if self.ydata_label == None:
                        self.ydata_label = raw_ydata.data_label
                    if opts != None:
                        for key, val in opts.items():
                            cur_line[key] = val

                    if self.hatch_groups != {}:
                        for h_key in self.hatch_groups:
                            # print("ss", h_key, akey, skey)
                            if h_key in str(skey):

                                _label = tuple([h_key, akey])
                                plot_label.replace(h_key, "")
                                if color is None:
                                    color = self._get_color(h_key)
                                cur_line["color"] = color
                                cur_line["label"] = plot_label
                                self.plot_areas[_label] = cur_line
                    else:
                        if color is None:
                            color = self._get_color("no_groups")
                        cur_line["color"] = color
                        cur_line["label"] = plot_label
                        self.plot_areas[akey] = cur_line
        self.plot_order = []
        for akey in self.area_groups:
            if isinstance(akey, dict):
                akey, item = list(akey.items())[0]
            for key in self.plot_areas.keys():
                if akey in key:
                    self.plot_order.append(key)
        # assert False

    def plotbreakdown(
        self,
        xdata,
        ydata,
        axis_options=None,
        generate_figure=True,
        legend_loc="upper left",
        legend_cols=2,
        fig_options={},
    ):

        self._select_data(xdata, ydata)
        self.selected_data = self.PsData.get_selected_data()
        self.selected_data.display()
        self.generate_groups_lines = self._get_group_options(
            self.selected_data.keys(), xdata, ydata
        )
        self.fig_options = fig_options
        self.index = 0
        if axis_options is None:
            self.axis_options = {}
        else:
            self.axis_options = axis_options
        if self.axis_options.get("xlabel") == None:
            self.axis_options["xlabel"] = self._get_axis_label(
                self.xdata_label, self.xunit
            )  # all lines shold share units
        if self.axis_options.get("ylabel") == None:
            self.axis_options["ylabel"] = self._get_axis_label(
                self.ydata_label, self.yunit
            )  # all lines shold share units

        self.plot_imported_data()

        if (
            generate_figure
        ):  # TODO: other plotters call this generate_plot, should make this consistent
            self.generate_figure(loc=legend_loc, cols=legend_cols)

    def plot_imported_data(self):
        if "fig_object" in self.fig_options:
            self.fig = self.fig_options.get("fig_object")
        else:
            self.fig = figureGenerator()
            self.fig.init_figure(**self.fig_options)
        plotted_legend = []
        for group, items in self.hatch_groups.items():
            self.fig.plot_area([], [], **items)
        old_data = 0
        current_data = None

        for linelabel in self.plot_order:
            line = self.plot_areas[linelabel]
            if line.get("label") in plotted_legend:
                line.pop("label")
            else:
                plotted_legend.append(line["label"])
            line["y2data"] = old_data
            if current_data is None:
                current_data = line["ydata"]
            else:
                current_data = line["ydata"] + old_data
            line["ydata"] = current_data
            if "ax_idx" in self.fig_options:
                line["ax_idx"] = self.fig_options.get("ax_idx")
            self.fig.plot_area(**line)
            old_data = line["ydata"]

    def generate_figure(self, loc="upper left", cols=2):
        if "ax_idx" in self.fig_options:
            self.axis_options["ax_idx"] = self.fig_options["ax_idx"]
        self.fig.set_axis(**self.axis_options)
        self.fig.add_legend(loc=loc, ncol=cols)

        if self.save_name is not None:
            self.fig.save(self.save_location, self.save_name)

        if self.show_fig:
            self.fig.show()

        self.fig.close()