Module fri.plot
View Source
# matplotlib.use("TkAgg")
import matplotlib.cm as cm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import dendrogram
# Color lookup for the relevance classes: entries 0, 1 and 2 of matplotlib's
# "Set1" colormap, sampled with alpha=0.8.  plot_relevance_bars indexes this
# palette with the integer class of each feature; its legend labels the three
# entries "Irrelevant", "Weakly relevant" and "Strongly relevant" in that order.
color_palette_3 = cm.Set1([0, 1, 2], alpha=0.8)
def plot_relevance_bars(
    ax, ranges, ticklabels=None, classes=None, numbering=True, tick_rotation=30
):
    """Draw one bar per feature spanning its lower to upper relevance bound.

    Parameters
    ----------
    ax:
        axis which the bars get drawn on
    ranges:
        2d array-like of floating values; column 0 holds the lower and
        column 1 the upper bound of each bar
    ticklabels: (optional)
        labels for each feature; without them the features are labelled 1..N
    classes: (optional)
        relevance class for each feature (index into the three-color
        palette), determines color; a legend is added when given
    numbering: bool
        Add feature index when using ticklabels
    tick_rotation:  int
        Amount of rotation of ticklabels for easier readability.

    Returns
    -------
    bars:
        the container returned by ``ax.bar``
    """
    # Convert once to a float array: list input becomes sliceable and the
    # minimal bar height below is not truncated to 0 for integer input.
    ranges = np.asarray(ranges, dtype=float)
    N = len(ranges)

    # Ticklabels
    if ticklabels is None:
        ticks = np.arange(N) + 1
    else:
        ticks = list(ticklabels)
        if numbering:
            # format() instead of "+=" so non-string labels work as well
            ticks = [
                "{} - {}".format(label, i + 1) for i, label in enumerate(ticks)
            ]

    # Interval sizes
    ind = np.arange(N) + 1
    width = 0.6
    lower_vals = ranges[:, 0]
    upper_vals = ranges[:, 1]
    # Minimal height to make very small intervals visible
    height = np.maximum(upper_vals - lower_vals, 0.001)

    # Bar colors: everything gets palette entry 0 when no classes are given.
    # int() accepts numpy scalars as well as plain Python numbers.
    if classes is None:
        color = [color_palette_3[0]] * N
    else:
        color = [color_palette_3[int(c)] for c in classes]

    # Plot the bars
    bars = ax.bar(
        ind,
        height,
        width,
        bottom=lower_vals,
        tick_label=ticks,
        align="center",
        edgecolor=["black"] * N,
        linewidth=1.3,
        color=color,
    )

    # Only user supplied labels are rotated; plain indices stay horizontal.
    if ticklabels is None:
        ax.set_xticklabels(ticks)
    else:
        ax.set_xticklabels(ticks, rotation=tick_rotation, ha="right")

    # Start the y axis at 0 and leave 10% headroom above the largest bound
    ax.set_ylim([0, max(upper_vals) * 1.1])
    ax.set_ylabel("relevance")
    ax.set_xlabel("feature")

    if classes is not None:
        # Legend entries follow the palette order (class 0, 1, 2)
        relevance_classes = ["Irrelevant", "Weakly relevant", "Strongly relevant"]
        patches = [
            mpatches.Patch(color=color_palette_3[i], label=rc)
            for i, rc in enumerate(relevance_classes)
        ]
        ax.legend(handles=patches)
    return bars
def plotIntervals(ranges, ticklabels=None, invert=False, classes=None):
    """Create a new figure showing the relevance intervals as bars.

    Parameters
    ----------
    ranges:
        2d array of lower and upper bounds, forwarded to plot_relevance_bars
    ticklabels: (optional)
        labels for each feature
    invert: bool
        Invert the x-axis, e.g. to compare the plot with other tools that
        list the features in reverse order
    classes: (optional)
        relevance class for each feature, determines color

    Returns
    -------
    fig:
        the newly created matplotlib figure
    """
    # Figure Parameters
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plot_relevance_bars(ax, ranges, ticklabels=ticklabels, classes=classes)
    # Let matplotlib rotate and right-align the x tick labels
    fig.autofmt_xdate()
    if invert:
        # Operate on the axes created above rather than on pyplot's
        # "current axes" global state.
        ax.invert_xaxis()
    return fig
def plot_dendrogram_and_intervals(
    intervals, linkage, figsize=(13, 7), ticklabels=None, classes=None, **kwargs
):
    """Plot a dendrogram on top of the relevance bars, both in leaf order.

    Parameters
    ----------
    intervals:
        2d array-like with one row (lower bound, upper bound) per feature
    linkage:
        linkage matrix handed to ``scipy.cluster.hierarchy.dendrogram``
    figsize: tuple
        size of the created figure
    ticklabels: (optional)
        labels for each feature in original feature order; the 1-based
        original feature index is appended to every label
    classes: (optional)
        relevance class for each feature in original feature order
    **kwargs:
        forwarded to plot_relevance_bars (e.g. ``tick_rotation``)

    Returns
    -------
    fig:
        the figure holding the dendrogram (top) and the bars (bottom)
    """
    fig, (ax2, ax) = plt.subplots(2, 1, figsize=figsize)
    # Top dendrogram plot
    d = dendrogram(
        linkage,
        color_threshold=0,
        leaf_rotation=0.0,  # rotates the x axis labels
        leaf_font_size=12.0,  # font size for the x axis labels
        ax=ax2,
    )
    # Get index determined through linkage method and dendrogram
    rearranged_index = d["leaves"]
    # np.asarray allows plain lists, which cannot be indexed with a list
    ranges = np.asarray(intervals)[rearranged_index]
    if ticklabels is None:
        ticks = np.array(rearranged_index) + 1  # Index starting at 1
    else:
        labels = list(ticklabels)
        ticks = ["{} - {}".format(labels[j], j + 1) for j in rearranged_index]
    if classes is not None:
        classes = np.asarray(classes)[rearranged_index]
    plot_relevance_bars(
        ax,
        ranges,
        ticklabels=ticks,
        classes=classes,
        numbering=False,  # the original feature index is already in the label
        **kwargs,
    )
    fig.subplots_adjust(hspace=0)
    ax.margins(x=0)
    ax2.set_xticks([])
    ax2.margins(x=0)
    # NOTE(review): tight_layout recomputes the subplot spacing, so the
    # hspace=0 set above may not survive this call - confirm the intent.
    fig.tight_layout()
    return fig
def plot_intervals(model, ticklabels=None):
    """Plot the relevance intervals of a fitted model in a new figure.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before; when ``model.interval_`` is None only a
        hint is printed.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    """
    # Guard clause: nothing to draw without computed intervals
    if model.interval_ is None:
        print("Intervals not computed. Try running fit() function first.")
        return
    plotIntervals(
        model.interval_, ticklabels=ticklabels, classes=model.relevance_classes_
    )
def plot_lupi_intervals(model, ticklabels=None, lupi_ticklabels=None):
    """Plot the relevance intervals of a LUPI model in two new figures.

    The last ``model.lupi_features_`` rows of ``model.interval_`` (and the
    matching entries of ``model.relevance_classes_``) are drawn in the second
    figure, all rows before them in the first one.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before; when ``model.interval_`` is None only a
        hint is printed.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    lupi_ticklabels : list of str, optional
        Strs for lupi ticklabels on x-axis (lupi features)
    """
    # Check before touching the intervals: slicing a missing result would
    # raise an AttributeError instead of showing the hint.
    if model.interval_ is None:
        print("Intervals not computed. Try running fit() function first.")
        return

    n_features = model.interval_.shape[0] - model.lupi_features_
    plotIntervals(
        model.interval_[0:n_features, :],
        ticklabels=ticklabels,
        classes=model.relevance_classes_[0:n_features],
    )
    plotIntervals(
        model.interval_[n_features:, :],
        ticklabels=lupi_ticklabels,
        classes=model.relevance_classes_[n_features:],
    )
#
# def interactive_scatter_embed(embedding, mode="markers", txt=None):
#     # TODO: extend method
#     import plotly.graph_objs as go
#     from plotly.offline import init_notebook_mode, iplot
#     init_notebook_mode(connected=True)
#     # Create a trace
#     trace = go.Scatter(
#         x=embedding[:, 0],
#         y=embedding[:, 1],
#         mode=mode,
#         text=txt if mode is "text" else None
#     )
#
#     data = [trace]
#
#     # Plot and embed in ipython notebook!
#     iplot(data)
Variables
color_palette_3
Functions
plotIntervals
def plotIntervals(
    ranges,
    ticklabels=None,
    invert=False,
    classes=None
)
View Source
def plotIntervals(ranges, ticklabels=None, invert=False, classes=None):
    """Create a new figure showing the relevance intervals as bars.

    Parameters
    ----------
    ranges:
        2d array of lower and upper bounds, forwarded to plot_relevance_bars
    ticklabels: (optional)
        labels for each feature
    invert: bool
        Invert the x-axis, e.g. to compare the plot with other tools that
        list the features in reverse order
    classes: (optional)
        relevance class for each feature, determines color

    Returns
    -------
    fig:
        the newly created matplotlib figure
    """
    # Figure Parameters
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plot_relevance_bars(ax, ranges, ticklabels=ticklabels, classes=classes)
    # Let matplotlib rotate and right-align the x tick labels
    fig.autofmt_xdate()
    if invert:
        # Operate on the axes created above rather than on pyplot's
        # "current axes" global state.
        ax.invert_xaxis()
    return fig
plot_dendrogram_and_intervals
def plot_dendrogram_and_intervals(
    intervals,
    linkage,
    figsize=(13, 7),
    ticklabels=None,
    classes=None,
    **kwargs
)
View Source
def plot_dendrogram_and_intervals(
    intervals, linkage, figsize=(13, 7), ticklabels=None, classes=None, **kwargs
):
    """Plot a dendrogram on top of the relevance bars, both in leaf order.

    Parameters
    ----------
    intervals:
        2d array-like with one row (lower bound, upper bound) per feature
    linkage:
        linkage matrix handed to ``scipy.cluster.hierarchy.dendrogram``
    figsize: tuple
        size of the created figure
    ticklabels: (optional)
        labels for each feature in original feature order; the 1-based
        original feature index is appended to every label
    classes: (optional)
        relevance class for each feature in original feature order
    **kwargs:
        forwarded to plot_relevance_bars (e.g. ``tick_rotation``)

    Returns
    -------
    fig:
        the figure holding the dendrogram (top) and the bars (bottom)
    """
    fig, (ax2, ax) = plt.subplots(2, 1, figsize=figsize)
    # Top dendrogram plot
    d = dendrogram(
        linkage,
        color_threshold=0,
        leaf_rotation=0.0,  # rotates the x axis labels
        leaf_font_size=12.0,  # font size for the x axis labels
        ax=ax2,
    )
    # Get index determined through linkage method and dendrogram
    rearranged_index = d["leaves"]
    # np.asarray allows plain lists, which cannot be indexed with a list
    ranges = np.asarray(intervals)[rearranged_index]
    if ticklabels is None:
        ticks = np.array(rearranged_index) + 1  # Index starting at 1
    else:
        labels = list(ticklabels)
        ticks = ["{} - {}".format(labels[j], j + 1) for j in rearranged_index]
    if classes is not None:
        classes = np.asarray(classes)[rearranged_index]
    plot_relevance_bars(
        ax,
        ranges,
        ticklabels=ticks,
        classes=classes,
        numbering=False,  # the original feature index is already in the label
        **kwargs,
    )
    fig.subplots_adjust(hspace=0)
    ax.margins(x=0)
    ax2.set_xticks([])
    ax2.margins(x=0)
    # NOTE(review): tight_layout recomputes the subplot spacing, so the
    # hspace=0 set above may not survive this call - confirm the intent.
    fig.tight_layout()
    return fig
plot_intervals
def plot_intervals(
    model,
    ticklabels=None
)
Plot the relevance intervals.
Parameters
model : FRI model. Needs to be fitted before.
ticklabels : list of str, optional. Strings for ticklabels on x-axis (features).
View Source
def plot_intervals(model, ticklabels=None):
    """Plot the relevance intervals of a fitted model in a new figure.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before; when ``model.interval_`` is None only a
        hint is printed.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    """
    # Guard clause: nothing to draw without computed intervals
    if model.interval_ is None:
        print("Intervals not computed. Try running fit() function first.")
        return
    plotIntervals(
        model.interval_, ticklabels=ticklabels, classes=model.relevance_classes_
    )
plot_lupi_intervals
def plot_lupi_intervals(
    model,
    ticklabels=None,
    lupi_ticklabels=None
)
Plot the relevance intervals.
Parameters
model : FRI model. Needs to be fitted before.
ticklabels : list of str, optional. Strings for ticklabels on x-axis (features).
lupi_ticklabels : list of str, optional. Strings for lupi ticklabels on x-axis (lupi features).
View Source
def plot_lupi_intervals(model, ticklabels=None, lupi_ticklabels=None):
    """Plot the relevance intervals of a LUPI model in two new figures.

    The last ``model.lupi_features_`` rows of ``model.interval_`` (and the
    matching entries of ``model.relevance_classes_``) are drawn in the second
    figure, all rows before them in the first one.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before; when ``model.interval_`` is None only a
        hint is printed.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    lupi_ticklabels : list of str, optional
        Strs for lupi ticklabels on x-axis (lupi features)
    """
    # Check before touching the intervals: slicing a missing result would
    # raise an AttributeError instead of showing the hint.
    if model.interval_ is None:
        print("Intervals not computed. Try running fit() function first.")
        return

    n_features = model.interval_.shape[0] - model.lupi_features_
    plotIntervals(
        model.interval_[0:n_features, :],
        ticklabels=ticklabels,
        classes=model.relevance_classes_[0:n_features],
    )
    plotIntervals(
        model.interval_[n_features:, :],
        ticklabels=lupi_ticklabels,
        classes=model.relevance_classes_[n_features:],
    )
plot_relevance_bars
def plot_relevance_bars(
    ax,
    ranges,
    ticklabels=None,
    classes=None,
    numbering=True,
    tick_rotation=30
)
Parameters
ax: axis which the bars get drawn on
ranges: the 2d array of floating values determining the lower and upper bounds of the bars
ticklabels: (optional) labels for each feature
classes: (optional) relevance class for each feature, determines color
numbering: bool. Add feature index when using ticklabels
tick_rotation: int. Amount of rotation of ticklabels for easier readability.
View Source
def plot_relevance_bars(
    ax, ranges, ticklabels=None, classes=None, numbering=True, tick_rotation=30
):
    """Draw one bar per feature spanning its lower to upper relevance bound.

    Parameters
    ----------
    ax:
        axis which the bars get drawn on
    ranges:
        2d array-like of floating values; column 0 holds the lower and
        column 1 the upper bound of each bar
    ticklabels: (optional)
        labels for each feature; without them the features are labelled 1..N
    classes: (optional)
        relevance class for each feature (index into the three-color
        palette), determines color; a legend is added when given
    numbering: bool
        Add feature index when using ticklabels
    tick_rotation:  int
        Amount of rotation of ticklabels for easier readability.

    Returns
    -------
    bars:
        the container returned by ``ax.bar``
    """
    # Convert once to a float array: list input becomes sliceable and the
    # minimal bar height below is not truncated to 0 for integer input.
    ranges = np.asarray(ranges, dtype=float)
    N = len(ranges)

    # Ticklabels
    if ticklabels is None:
        ticks = np.arange(N) + 1
    else:
        ticks = list(ticklabels)
        if numbering:
            # format() instead of "+=" so non-string labels work as well
            ticks = [
                "{} - {}".format(label, i + 1) for i, label in enumerate(ticks)
            ]

    # Interval sizes
    ind = np.arange(N) + 1
    width = 0.6
    lower_vals = ranges[:, 0]
    upper_vals = ranges[:, 1]
    # Minimal height to make very small intervals visible
    height = np.maximum(upper_vals - lower_vals, 0.001)

    # Bar colors: everything gets palette entry 0 when no classes are given.
    # int() accepts numpy scalars as well as plain Python numbers.
    if classes is None:
        color = [color_palette_3[0]] * N
    else:
        color = [color_palette_3[int(c)] for c in classes]

    # Plot the bars
    bars = ax.bar(
        ind,
        height,
        width,
        bottom=lower_vals,
        tick_label=ticks,
        align="center",
        edgecolor=["black"] * N,
        linewidth=1.3,
        color=color,
    )

    # Only user supplied labels are rotated; plain indices stay horizontal.
    if ticklabels is None:
        ax.set_xticklabels(ticks)
    else:
        ax.set_xticklabels(ticks, rotation=tick_rotation, ha="right")

    # Start the y axis at 0 and leave 10% headroom above the largest bound
    ax.set_ylim([0, max(upper_vals) * 1.1])
    ax.set_ylabel("relevance")
    ax.set_xlabel("feature")

    if classes is not None:
        # Legend entries follow the palette order (class 0, 1, 2)
        relevance_classes = ["Irrelevant", "Weakly relevant", "Strongly relevant"]
        patches = [
            mpatches.Patch(color=color_palette_3[i], label=rc)
            for i, rc in enumerate(relevance_classes)
        ]
        ax.legend(handles=patches)
    return bars