Skip to content

LinePlotter

LinePlotter

Source code in src/psPlotKit/data_plotter/ps_line_plotter.py
class LinePlotter:
    """Plot one or more y-vs-x lines selected from a ``PsData`` container.

    Typical use: optionally call ``define_line_groups``, ``define_index_labels``
    and ``define_line_colors`` / ``specify_line_colors``, then
    ``plot_line(xdata, ydata)``.
    """

    # Default palette for ``define_line_colors`` (these appear to be the first
    # 11 colours of the ColorBrewer "Paired" scheme).
    DEFAULT_LINE_COLORS = (
        "#a6cee3",
        "#1f78b4",
        "#b2df8a",
        "#33a02c",
        "#fb9a99",
        "#e31a1c",
        "#fdbf6f",
        "#ff7f00",
        "#cab2d6",
        "#6a3d9a",
        "#ffff99",
    )

    def __init__(
        self, PsData, save_location="", save_folder=None, save_name=None, show_fig=True
    ):
        """Create a plotter bound to ``PsData``.

        Args:
            PsData: data container exposing ``select_data`` and
                ``get_selected_data``.
            save_location: base location passed to ``create_save_location``.
            save_folder: optional folder passed to ``create_save_location``.
            save_name: name used by ``plot_line`` to save the figure; ``None``
                disables saving.
            show_fig: whether ``plot_line`` shows the figure.
        """
        self.save_location = create_save_location(save_location, save_folder)
        self.show_fig = show_fig
        self.select_data_key_list = []
        self.PsData = PsData
        self.define_line_colors()
        self.line_indexes = {}
        self.line_groups = {}
        self.xunit = None
        self.yunit = None
        self.xdata_label = None
        self.ydata_label = None
        self.save_name = save_name
        self.data_index_to_label = {}

    def _select_data(self, xdata, ydata):
        """Select the x data, then add the y data to the same selection."""
        self.PsData.select_data(xdata, require_all_in_dir=False)
        self.PsData.select_data(ydata, require_all_in_dir=False, add_to_existing=True)

    def define_line_groups(self, line_groups=None):
        """Set the line groups and reset the colour bookkeeping.

        Args:
            line_groups: mapping of group key -> keyword options later passed
                to ``fig.plot_line``. A data key belongs to a group when
                ``str(group_key)`` occurs in ``str(data_key)``. ``None`` clears
                all groups.
        """
        # Storing None verbatim made every later iteration over the groups fail.
        self.line_groups = {} if line_groups is None else line_groups
        self.line_indexes = {}

    def define_index_labels(self, index_labels):
        """Map data-key fragments to legend labels.

        Args:
            index_labels: mapping of key -> label. A label is either a plain
                value or a dict with ``"label"`` and optional ``"marker"``,
                ``"markersize"`` and ``"color"``. An entry applies when
                ``str(key)`` occurs in ``str(data_key)``; the last match wins.
        """
        self.data_index_to_label = index_labels

    def specify_line_colors(self, line_colors):
        """Assign fixed colours to labels, replacing all colour bookkeeping.

        Args:
            line_colors: mapping of label -> palette index (int) or a colour
                value that ``_get_color`` returns unchanged.
        """
        self.line_indexes = {
            label: {"idx": value, "auto": False} for label, value in line_colors.items()
        }

    @staticmethod
    def _first_str_match(values, target):
        """Return the index of the first item of ``values`` whose ``str`` equals ``str(target)``."""
        wanted = str(target)
        for position, value in enumerate(values):
            if str(value) == wanted:
                return position
        return None

    def _get_color(self, label, count_idx=None, single_group=False):
        """Return the colour for a line.

        Args:
            label: label (or iterable of labels) looked up in
                ``self.line_indexes``. When ``self.line_colors`` is a dict it is
                used directly as the key.
            count_idx: identifier shared by lines that should get the same
                colour across groups (the per-line label in grouped mode).
            single_group: True when called without line groups.

        Returns:
            ``self.line_colors[idx]`` (wrapping around the palette) when the
            resolved index is an int, otherwise the resolved value itself.
        """
        if isinstance(self.line_colors, dict):
            return self.line_colors[label]

        idx = None
        if isinstance(label, str):
            label = [label]
        for key in label:
            if key in self.line_indexes:
                idx = self.line_indexes[key]["idx"]
                break

        if idx is None or not single_group:
            auto_label = "single_group" if single_group else "auto_{}".format(label)
            # count_idx values are stored raw but tested as str(count_idx), so
            # only string ids are ever matched; non-string ids (including the
            # default None used in single-group mode) always advance the counter.
            if auto_label in self.line_indexes:
                seen = self.line_indexes["count_idxs"]
                if str(count_idx) in seen:
                    idx = self._first_str_match(seen, count_idx)
                else:
                    self.line_indexes[auto_label]["idx"] += 1
                    idx = self.line_indexes[auto_label]["idx"]
                    seen.append(count_idx)
            else:
                idx = 0
                self.line_indexes[auto_label] = {"idx": 0}
                if "count_idxs" not in self.line_indexes:
                    self.line_indexes["count_idxs"] = [count_idx]
                elif str(count_idx) in self.line_indexes["count_idxs"]:
                    idx = self._first_str_match(
                        self.line_indexes["count_idxs"], count_idx
                    )
                # NOTE(review): an unseen count_idx is not recorded on this path — confirm intended.

        if isinstance(idx, int):
            # Wrap around so more lines than palette entries reuse colours
            # instead of raising IndexError.
            return self.line_colors[idx % len(self.line_colors)]
        return idx

    def define_line_colors(self, line_color_list=None):
        """Set the colour palette; ``None`` restores the default palette."""
        if line_color_list is None:
            self.line_colors = list(self.DEFAULT_LINE_COLORS)
        else:
            self.line_colors = line_color_list

    def _get_data(self, data, key):
        """Return ``(keys, values)`` for entries of ``data`` whose key contains ``key``.

        Each returned key is a tuple of the original key's components with
        ``key`` removed (a non-tuple key becomes a 1-tuple).
        """
        data_keys = []
        data_list = []
        for dkey, value in data.items():
            if key not in dkey:
                continue
            if isinstance(dkey, tuple):
                data_keys.append(tuple(k for k in dkey if k != key))
            else:
                data_keys.append((dkey,))
            data_list.append(value)
        return data_keys, data_list

    def _get_axis_label(self, label, units):
        """Format an axis label as ``"label (units)"``."""
        return "{} ({})".format(label, units)

    def _get_ydata(self, selected_keys, ydata):
        """Return the keys whose string form contains ``ydata`` (every item, if a list/tuple)."""
        wanted = ydata if isinstance(ydata, (list, tuple)) else [ydata]
        return [skey for skey in selected_keys if all(y in str(skey) for y in wanted)]

    def _replace_key(self, skey, xdata):
        """Return ``skey`` with its last component replaced by ``xdata``."""
        dir_key = list(skey)
        dir_key[-1] = xdata
        return tuple(dir_key)

    def check_key_in_dir(self, udir, key):
        """Return True if ``key`` equals a string entry of ``udir`` or an element of a non-string entry."""
        for entry in udir:
            if isinstance(entry, str):
                if entry == key:
                    return True
            elif any(key == item for item in entry):
                return True
        return False

    def _group_options_for(self, skey):
        """Return the options of the last line group matching ``skey`` (or None).

        Matching groups get default ``label``/``color``/``marker`` filled in
        place and a ``line_indexes`` entry.
        """
        opts = None
        for group_key, group_opts in self.line_groups.items():
            if str(group_key) not in str(skey):
                continue
            opts = group_opts
            if opts.get("label") is None:
                opts["label"] = group_key
            if opts.get("color") is None and opts.get("marker") is None:
                opts["color"] = "black"
            if opts.get("marker") is None:
                opts["marker"] = "o"
            if group_key not in self.line_indexes:
                self.line_indexes[group_key] = {"idx": 0, "auto": True}
        return opts

    def _index_label_for(self, skey, cur_line):
        """Apply ``data_index_to_label`` entries matching ``skey`` to ``cur_line``.

        Returns:
            The label of the last matching entry, or None if none match.
        """
        label = None
        for index_key, item in self.data_index_to_label.items():
            if str(index_key) not in str(skey):
                continue
            if isinstance(item, dict):
                label = item["label"]
                for option in ("marker", "markersize", "color"):
                    if item.get(option) is not None:
                        cur_line[option] = item[option]
            else:
                label = item
        return label

    @staticmethod
    def _default_label(sk, ydata):
        """Build a label from key component ``sk`` with the ydata name(s) removed."""
        if not isinstance(sk, tuple):
            return sk
        parts = list(sk)
        for name in ydata if isinstance(ydata, (list, tuple)) else [ydata]:
            if name in parts:
                parts.remove(name)
        label = parts[0] if len(parts) == 1 else parts
        if isinstance(label, str):
            return label
        try:
            return " ".join(map(str, label))
        except TypeError:
            # A single non-iterable component (e.g. a number) cannot be joined.
            return str(label)

    def _add_plot_line(self, skey, sk, xdata, ydata, opts):
        """Build the options for data key ``skey`` and store them in ``self.plot_lines``."""
        cur_line = {}
        label = self._index_label_for(skey, cur_line)
        if label is None:
            label = self._default_label(sk, ydata)
        plot_label = label

        raw_ydata = self.selected_data[skey]
        raw_xdata = self.selected_data[self._replace_key(skey, xdata)]
        cur_line["ydata"] = raw_ydata.data
        cur_line["xdata"] = raw_xdata.data

        # The first line defines axis units/labels; all lines should share them.
        if self.xunit is None:
            self.xunit = raw_xdata.mpl_units
        if self.xdata_label is None:
            self.xdata_label = raw_xdata.data_label
        if self.yunit is None:
            self.yunit = raw_ydata.mpl_units
        if self.ydata_label is None:
            self.ydata_label = raw_ydata.data_label

        if opts is not None:
            cur_line.update(opts)

        if self.line_groups:
            for g_key in self.line_groups:
                if str(g_key) not in str(skey):
                    continue
                # NOTE(review): the group key is not stripped from plot_label — confirm intended.
                if cur_line.get("color") is None:
                    cur_line["color"] = self._get_color(g_key, count_idx=label)
                # Nesting keeps plot_lines keys distinct per group for equal labels.
                label = (g_key, label)
        else:
            # Respect an explicit colour from define_index_labels, as grouped mode does.
            if cur_line.get("color") is None:
                cur_line["color"] = self._get_color(plot_label, single_group=True)
            if cur_line.get("marker") is None:
                cur_line["marker"] = "o"
        cur_line["label"] = plot_label
        self.plot_lines[label] = cur_line

    def _get_group_options(self, selected_keys, xdata, ydata):
        """Populate ``self.plot_lines`` from the selected keys that contain ``ydata``.

        For each such key, the first component that matches ``ydata`` supplies
        the default label. Returns None.
        """
        self.plot_lines = {}
        ydata_is_multi = isinstance(ydata, (list, tuple))
        for skey in self._get_ydata(selected_keys, ydata):
            opts = self._group_options_for(skey)
            for sk in skey:
                if ydata_is_multi:
                    # Tests the whole key, so the first component always matches.
                    matched = all(ykey in str(skey) for ykey in ydata)
                else:
                    matched = ydata in str(sk)
                if matched:
                    self._add_plot_line(skey, sk, xdata, ydata, opts)
                    break

    def plot_line(
        self, xdata, ydata, axis_options=None, fig_options=None, generate_plot=True
    ):
        """Select data, build the lines, draw them, then save/show/close the figure.

        Args:
            xdata: x-data key; it replaces the last component of each y key to
                find the matching x data.
            ydata: y-data key, or list/tuple of keys that must all occur in a
                data key's string form.
            axis_options: keyword options for ``fig.set_axis`` and
                ``fig.add_legend``. Missing ``xlabel``/``ylabel`` are filled
                from the data labels and units. The caller's dict is copied,
                not mutated.
            fig_options: options for ``plot_imported_data`` (``init_figure``
                kwargs, ``"fig_object"``, ``"ax_idx"``). Defaults to ``{}``.
            generate_plot: whether to call ``generate_figure``.
        """
        if fig_options is None:
            fig_options = {}
        self._select_data(xdata, ydata)
        self.selected_data = self.PsData.get_selected_data()
        self.selected_data.display()
        self._get_group_options(self.selected_data.keys(), xdata, ydata)
        # _get_group_options returns None; attribute kept for backward compatibility.
        self.generate_groups_lines = None
        self.index = 0
        # Copy so a dict reused across calls does not keep stale axis labels.
        self.axis_options = {} if axis_options is None else dict(axis_options)
        if self.axis_options.get("xlabel") is None:
            self.axis_options["xlabel"] = self._get_axis_label(
                self.xdata_label, self.xunit
            )
        if self.axis_options.get("ylabel") is None:
            self.axis_options["ylabel"] = self._get_axis_label(
                self.ydata_label, self.yunit
            )

        self.plot_imported_data(fig_options)

        if generate_plot:
            self.generate_figure()

        if self.save_name is not None:
            self.fig.save(self.save_location, self.save_name)

        if self.show_fig:
            self.fig.show()

        self.fig.close()

    def plot_imported_data(self, fig_options):
        """Create (or reuse) the figure, then draw the group entries and all lines.

        Args:
            fig_options: kwargs for ``FigureGenerator.init_figure``. If it
                contains ``"fig_object"``, that figure is reused and
                ``init_figure`` is not called. An ``"ax_idx"`` entry is applied
                to every drawn line.
        """
        if "fig_object" in fig_options:
            self.fig = fig_options.get("fig_object")
        else:
            self.fig = FigureGenerator()
            self.fig.init_figure(**fig_options)
        has_ax_idx = "ax_idx" in fig_options
        # Each group is drawn as an empty line, presumably for its legend entry.
        for items in self.line_groups.values():
            if has_ax_idx:
                items["ax_idx"] = fig_options.get("ax_idx")
            if items.get("color") is None:
                items["color"] = "black"
            self.fig.plot_line([], [], **items)
        plotted_legend = []
        for line in self.plot_lines.values():
            if has_ax_idx:
                line["ax_idx"] = fig_options.get("ax_idx")
            # Only the first line with a given label keeps it (no duplicate legend entries).
            if line.get("label") in plotted_legend:
                line.pop("label")
            else:
                plotted_legend.append(line["label"])
            self.fig.plot_line(**line)

    def generate_figure(self):
        """Apply ``self.axis_options`` to the figure's axes and legend."""
        self.fig.set_axis(**self.axis_options)
        self.fig.add_legend(**self.axis_options)