Skip to content

Module fri.plot

View Source
# matplotlib.use("TkAgg")

import matplotlib.cm as cm

import matplotlib.patches as mpatches

import matplotlib.pyplot as plt

import numpy as np

from scipy.cluster.hierarchy import dendrogram

# One RGBA color (alpha 0.8) per relevance class index 0, 1, 2, taken from
# the first three entries of matplotlib's qualitative "Set1" colormap.
color_palette_3 = cm.Set1(np.arange(3), alpha=0.8)

def plot_relevance_bars(
    ax, ranges, ticklabels=None, classes=None, numbering=True, tick_rotation=30
):
    """Draw one floating bar per feature covering its relevance interval.

    Parameters
    ----------
    ax:
        axis which the bars get drawn on
    ranges:
        2d array-like of floating values; column 0 holds the lower and
        column 1 the upper bound of each feature's bar
    ticklabels: (optional)
        labels for each feature; if omitted, features are numbered from 1
    classes: (optional)
        relevance class for each feature (0, 1 or 2); determines the bar
        color and adds a legend
    numbering: bool
        Append the 1-based feature index (" - i") when using ticklabels
    tick_rotation: int
        Amount of rotation (in degrees) of ticklabels for easier readability.
        Only applied when ticklabels are given.

    Returns
    -------
    The bar container returned by ``ax.bar``.
    """
    # A float array lets list input be column-indexed and keeps the minimal
    # height clamp below from being truncated to 0 for integer input.
    ranges = np.asarray(ranges, dtype=float)
    N = len(ranges)

    # Ticklabels
    if ticklabels is None:
        ticks = np.arange(N) + 1
    else:
        ticks = list(ticklabels)
        if numbering:
            # format() also accepts non-string labels (e.g. ints).
            ticks = [
                "{} - {}".format(label, i) for i, label in enumerate(ticks, start=1)
            ]

    # Interval sizes
    ind = np.arange(N) + 1
    width = 0.6
    lower_vals = ranges[:, 0]
    upper_vals = ranges[:, 1]
    height = upper_vals - lower_vals
    # Minimal height to make very small intervals visible
    height[height < 0.001] = 0.001

    # Bar colors: the class index selects a row of color_palette_3
    # (class 0 for every feature when no classes are given).
    if classes is None:
        color = [color_palette_3[0]] * N
    else:
        # int() works for plain Python ints as well as NumPy scalars.
        color = [color_palette_3[int(c)] for c in classes]

    # Plot the bars
    bars = ax.bar(
        ind,
        height,
        width,
        bottom=lower_vals,
        tick_label=ticks,
        align="center",
        edgecolor=["black"] * N,
        linewidth=1.3,
        color=color,
    )

    # tick_label above already sets the x tick labels; only restyle them here.
    if ticklabels is not None:
        plt.setp(ax.get_xticklabels(), rotation=tick_rotation, ha="right")

    # Limit the y range to the tallest drawn bar (including the minimal-height
    # clamp), so tiny bars stay visible even when every upper bound is 0.
    top = np.max(lower_vals + height) if N else 1.0
    ax.set_ylim([0, top * 1.1])
    ax.set_ylabel("relevance")
    ax.set_xlabel("feature")

    if classes is not None:
        relevance_classes = ["Irrelevant", "Weakly relevant", "Strongly relevant"]
        patches = [
            mpatches.Patch(color=color_palette_3[i], label=rc)
            for i, rc in enumerate(relevance_classes)
        ]
        ax.legend(handles=patches)

    return bars

def plotIntervals(ranges, ticklabels=None, invert=False, classes=None):
    """Create a new figure showing the relevance interval of every feature.

    Parameters
    ----------
    ranges:
        2d array of lower/upper relevance bounds, one row per feature;
        forwarded to ``plot_relevance_bars``.
    ticklabels: (optional)
        labels for each feature, forwarded to ``plot_relevance_bars``
    invert: bool
        Whether to invert the x-axis (intended for comparisons with other
        tools).
    classes: (optional)
        relevance class per feature, forwarded to ``plot_relevance_bars``

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plot_relevance_bars(ax, ranges, ticklabels=ticklabels, classes=classes)
    # Rotates and right-aligns the x tick labels so long names do not overlap.
    fig.autofmt_xdate()

    if invert:
        # Invert this figure's own axis; plt.gca() may point at another figure.
        ax.invert_xaxis()

    return fig

def plot_dendrogram_and_intervals(
    intervals, linkage, figsize=(13, 7), ticklabels=None, classes=None, **kwargs
):
    """Plot a feature dendrogram above the relevance intervals in leaf order.

    Parameters
    ----------
    intervals:
        2d array of lower/upper relevance bounds, one row per feature
    linkage:
        linkage matrix passed to ``scipy.cluster.hierarchy.dendrogram``
    figsize: tuple
        size of the created figure
    ticklabels: (optional)
        labels for each feature, in original feature order
    classes: (optional)
        relevance class per feature, in original feature order
    **kwargs:
        forwarded to ``plot_relevance_bars``

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the dendrogram (top) and interval bars (bottom).
    """
    fig, (ax2, ax) = plt.subplots(2, 1, figsize=figsize)

    # Top dendrogram plot
    d = dendrogram(
        linkage,
        color_threshold=0,
        leaf_rotation=0.0,  # rotates the x axis labels
        leaf_font_size=12.0,  # font size for the x axis labels
        ax=ax2,
    )

    # Feature order determined through linkage method and dendrogram
    rearranged_index = d["leaves"]
    # np.asarray allows plain lists to be reordered with a list of indices.
    ranges = np.asarray(intervals)[rearranged_index]

    if ticklabels is None:
        ticks = np.array(rearranged_index) + 1  # Index starting at 1
    else:
        ordered_labels = np.asarray(ticklabels)[rearranged_index]
        # Keep the original (1-based) feature index next to each label.
        ticks = [
            "{} - {}".format(label, idx + 1)
            for label, idx in zip(ordered_labels, rearranged_index)
        ]

    ordered_classes = (
        np.asarray(classes)[rearranged_index] if classes is not None else None
    )
    plot_relevance_bars(
        ax,
        ranges,
        ticklabels=ticks,
        classes=ordered_classes,
        numbering=False,
        **kwargs,
    )

    ax.margins(x=0)
    ax2.set_xticks([])
    ax2.margins(x=0)
    # tight_layout recomputes the vertical spacing between subplots, so the
    # zero gap between the two panels has to be applied afterwards.
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)

    return fig

def plot_intervals(model, ticklabels=None):
    """Plot the relevance intervals.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)

    Returns
    -------
    matplotlib.figure.Figure or None
        The created figure, or None (after printing a hint) when the model
        has no computed intervals.
    """
    # getattr: an unfitted model may not define ``interval_`` at all.
    interval = getattr(model, "interval_", None)
    if interval is None:
        print("Intervals not computed. Try running fit() function first.")
        return None

    return plotIntervals(
        interval, ticklabels=ticklabels, classes=model.relevance_classes_
    )

def plot_lupi_intervals(model, ticklabels=None, lupi_ticklabels=None):
    """Plot the relevance intervals of regular and LUPI features separately.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    lupi_ticklabels : list of str, optional
        Strs for lupi ticklabels on x-axis (lupi features)

    Returns
    -------
    tuple of matplotlib.figure.Figure or None
        ``(data_figure, lupi_figure)``, or None (after printing a hint) when
        the model has no computed intervals.
    """
    # Check before touching ``interval_.shape`` so an unfitted model gets the
    # hint instead of an AttributeError.
    interval = getattr(model, "interval_", None)
    if interval is None:
        print("Intervals not computed. Try running fit() function first.")
        return None

    # The last ``lupi_features_`` rows belong to the LUPI features.
    n_features = interval.shape[0] - model.lupi_features_
    classes = model.relevance_classes_

    data_fig = plotIntervals(
        interval[:n_features, :],
        ticklabels=ticklabels,
        classes=classes[:n_features],
    )
    lupi_fig = plotIntervals(
        interval[n_features:, :],
        ticklabels=lupi_ticklabels,
        classes=classes[n_features:],
    )
    return data_fig, lupi_fig

#

# def interactive_scatter_embed(embedding, mode="markers", txt=None):

#     # TODO: extend method

#     import plotly.graph_objs as go

#     from plotly.offline import init_notebook_mode, iplot

#     init_notebook_mode(connected=True)

#     # Create a trace

#     trace = go.Scatter(

#         x=embedding[:, 0],

#         y=embedding[:, 1],

#         mode=mode,

#         text=txt if mode is "text" else None

#     )

#

#     data = [trace]

#

#     # Plot and embed in ipython notebook!

#     iplot(data)

Variables

color_palette_3

Functions

plotIntervals

def plotIntervals(
    ranges,
    ticklabels=None,
    invert=False,
    classes=None
)
View Source
def plotIntervals(ranges, ticklabels=None, invert=False, classes=None):
    """Create a new figure showing the relevance interval of every feature.

    Parameters
    ----------
    ranges:
        2d array of lower/upper relevance bounds, one row per feature;
        forwarded to ``plot_relevance_bars``.
    ticklabels: (optional)
        labels for each feature, forwarded to ``plot_relevance_bars``
    invert: bool
        Whether to invert the x-axis (intended for comparisons with other
        tools).
    classes: (optional)
        relevance class per feature, forwarded to ``plot_relevance_bars``

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plot_relevance_bars(ax, ranges, ticklabels=ticklabels, classes=classes)
    # Rotates and right-aligns the x tick labels so long names do not overlap.
    fig.autofmt_xdate()

    if invert:
        # Invert this figure's own axis; plt.gca() may point at another figure.
        ax.invert_xaxis()

    return fig

plot_dendrogram_and_intervals

def plot_dendrogram_and_intervals(
    intervals,
    linkage,
    figsize=(13, 7),
    ticklabels=None,
    classes=None,
    **kwargs
)
View Source
def plot_dendrogram_and_intervals(
    intervals, linkage, figsize=(13, 7), ticklabels=None, classes=None, **kwargs
):
    """Plot a feature dendrogram above the relevance intervals in leaf order.

    Parameters
    ----------
    intervals:
        2d array of lower/upper relevance bounds, one row per feature
    linkage:
        linkage matrix passed to ``scipy.cluster.hierarchy.dendrogram``
    figsize: tuple
        size of the created figure
    ticklabels: (optional)
        labels for each feature, in original feature order
    classes: (optional)
        relevance class per feature, in original feature order
    **kwargs:
        forwarded to ``plot_relevance_bars``

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the dendrogram (top) and interval bars (bottom).
    """
    fig, (ax2, ax) = plt.subplots(2, 1, figsize=figsize)

    # Top dendrogram plot
    d = dendrogram(
        linkage,
        color_threshold=0,
        leaf_rotation=0.0,  # rotates the x axis labels
        leaf_font_size=12.0,  # font size for the x axis labels
        ax=ax2,
    )

    # Feature order determined through linkage method and dendrogram
    rearranged_index = d["leaves"]
    # np.asarray allows plain lists to be reordered with a list of indices.
    ranges = np.asarray(intervals)[rearranged_index]

    if ticklabels is None:
        ticks = np.array(rearranged_index) + 1  # Index starting at 1
    else:
        ordered_labels = np.asarray(ticklabels)[rearranged_index]
        # Keep the original (1-based) feature index next to each label.
        ticks = [
            "{} - {}".format(label, idx + 1)
            for label, idx in zip(ordered_labels, rearranged_index)
        ]

    ordered_classes = (
        np.asarray(classes)[rearranged_index] if classes is not None else None
    )
    plot_relevance_bars(
        ax,
        ranges,
        ticklabels=ticks,
        classes=ordered_classes,
        numbering=False,
        **kwargs,
    )

    ax.margins(x=0)
    ax2.set_xticks([])
    ax2.margins(x=0)
    # tight_layout recomputes the vertical spacing between subplots, so the
    # zero gap between the two panels has to be applied afterwards.
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)

    return fig

plot_intervals

def plot_intervals(
    model,
    ticklabels=None
)

Plot the relevance intervals.

Parameters

model : FRI model
    Needs to be fitted before.
ticklabels : list of str, optional
    Strings for the tick labels on the x-axis (features).

View Source
def plot_intervals(model, ticklabels=None):
    """Plot the relevance intervals.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)

    Returns
    -------
    matplotlib.figure.Figure or None
        The created figure, or None (after printing a hint) when the model
        has no computed intervals.
    """
    # getattr: an unfitted model may not define ``interval_`` at all.
    interval = getattr(model, "interval_", None)
    if interval is None:
        print("Intervals not computed. Try running fit() function first.")
        return None

    return plotIntervals(
        interval, ticklabels=ticklabels, classes=model.relevance_classes_
    )

plot_lupi_intervals

def plot_lupi_intervals(
    model,
    ticklabels=None,
    lupi_ticklabels=None
)

Plot the relevance intervals.

Parameters

model : FRI model
    Needs to be fitted before.
ticklabels : list of str, optional
    Strings for the tick labels on the x-axis (features).
lupi_ticklabels : list of str, optional
    Strings for the tick labels on the x-axis of the LUPI (privileged) features.

View Source
def plot_lupi_intervals(model, ticklabels=None, lupi_ticklabels=None):
    """Plot the relevance intervals of regular and LUPI features separately.

    Parameters
    ----------
    model : FRI model
        Needs to be fitted before.
    ticklabels : list of str, optional
        Strs for ticklabels on x-axis (features)
    lupi_ticklabels : list of str, optional
        Strs for lupi ticklabels on x-axis (lupi features)

    Returns
    -------
    tuple of matplotlib.figure.Figure or None
        ``(data_figure, lupi_figure)``, or None (after printing a hint) when
        the model has no computed intervals.
    """
    # Check before touching ``interval_.shape`` so an unfitted model gets the
    # hint instead of an AttributeError.
    interval = getattr(model, "interval_", None)
    if interval is None:
        print("Intervals not computed. Try running fit() function first.")
        return None

    # The last ``lupi_features_`` rows belong to the LUPI features.
    n_features = interval.shape[0] - model.lupi_features_
    classes = model.relevance_classes_

    data_fig = plotIntervals(
        interval[:n_features, :],
        ticklabels=ticklabels,
        classes=classes[:n_features],
    )
    lupi_fig = plotIntervals(
        interval[n_features:, :],
        ticklabels=lupi_ticklabels,
        classes=classes[n_features:],
    )
    return data_fig, lupi_fig

plot_relevance_bars

def plot_relevance_bars(
    ax,
    ranges,
    ticklabels=None,
    classes=None,
    numbering=True,
    tick_rotation=30
)

Parameters

ax: axis which the bars get drawn on
ranges: the 2d array of floating values determining the lower and upper bounds of the bars
ticklabels: (optional) labels for each feature
classes: (optional) relevance class for each feature, determines color
numbering: bool — add the feature index when using ticklabels
tick_rotation: int — amount of rotation of the ticklabels for easier readability.

View Source
def plot_relevance_bars(
    ax, ranges, ticklabels=None, classes=None, numbering=True, tick_rotation=30
):
    """Draw one floating bar per feature covering its relevance interval.

    Parameters
    ----------
    ax:
        axis which the bars get drawn on
    ranges:
        2d array-like of floating values; column 0 holds the lower and
        column 1 the upper bound of each feature's bar
    ticklabels: (optional)
        labels for each feature; if omitted, features are numbered from 1
    classes: (optional)
        relevance class for each feature (0, 1 or 2); determines the bar
        color and adds a legend
    numbering: bool
        Append the 1-based feature index (" - i") when using ticklabels
    tick_rotation: int
        Amount of rotation (in degrees) of ticklabels for easier readability.
        Only applied when ticklabels are given.

    Returns
    -------
    The bar container returned by ``ax.bar``.
    """
    # A float array lets list input be column-indexed and keeps the minimal
    # height clamp below from being truncated to 0 for integer input.
    ranges = np.asarray(ranges, dtype=float)
    N = len(ranges)

    # Ticklabels
    if ticklabels is None:
        ticks = np.arange(N) + 1
    else:
        ticks = list(ticklabels)
        if numbering:
            # format() also accepts non-string labels (e.g. ints).
            ticks = [
                "{} - {}".format(label, i) for i, label in enumerate(ticks, start=1)
            ]

    # Interval sizes
    ind = np.arange(N) + 1
    width = 0.6
    lower_vals = ranges[:, 0]
    upper_vals = ranges[:, 1]
    height = upper_vals - lower_vals
    # Minimal height to make very small intervals visible
    height[height < 0.001] = 0.001

    # Bar colors: the class index selects a row of color_palette_3
    # (class 0 for every feature when no classes are given).
    if classes is None:
        color = [color_palette_3[0]] * N
    else:
        # int() works for plain Python ints as well as NumPy scalars.
        color = [color_palette_3[int(c)] for c in classes]

    # Plot the bars
    bars = ax.bar(
        ind,
        height,
        width,
        bottom=lower_vals,
        tick_label=ticks,
        align="center",
        edgecolor=["black"] * N,
        linewidth=1.3,
        color=color,
    )

    # tick_label above already sets the x tick labels; only restyle them here.
    if ticklabels is not None:
        plt.setp(ax.get_xticklabels(), rotation=tick_rotation, ha="right")

    # Limit the y range to the tallest drawn bar (including the minimal-height
    # clamp), so tiny bars stay visible even when every upper bound is 0.
    top = np.max(lower_vals + height) if N else 1.0
    ax.set_ylim([0, top * 1.1])
    ax.set_ylabel("relevance")
    ax.set_xlabel("feature")

    if classes is not None:
        relevance_classes = ["Irrelevant", "Weakly relevant", "Strongly relevant"]
        patches = [
            mpatches.Patch(color=color_palette_3[i], label=rc)
            for i, rc in enumerate(relevance_classes)
        ]
        ax.legend(handles=patches)

    return bars