"""Plotting helpers for transparentai.datasets.variable.variable_plots."""

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import pandas as pd

from transparentai import utils
from transparentai import plots
from .variable import describe


DEFAULT_COLOR = '#3498db'


def plot_table_describe(ax, cell_text):
    """Insert a table in a matplotlib graphic using an axis.

    The label of the first row is colored green and the label of the
    second row red; every label of the first column is drawn in bold.
    The axis frame and ticks are hidden afterwards.

    Parameters
    ----------
    ax: plt.axes.Axes
        axe where to add the plot
    cell_text: list(list)
        The texts to place into the table cells (needs at least two rows).
    """
    stats_table = ax.table(cellText=cell_text, cellLoc='left', loc='center')

    # Highlight the first-column labels of the first two rows.
    for row, colour in ((0, 'green'), (1, 'red')):
        stats_table[(row, 0)].get_text().set_color(colour)

    # Emphasise every label of the first column.
    for row, _ in enumerate(cell_text):
        stats_table[(row, 0)].get_text().set_weight('bold')

    stats_table.set_fontsize(24)
    stats_table.scale(1, 3)
    ax.axis('off')


def plot_number_var(ax, arr, color=DEFAULT_COLOR, label=None, alpha=1.):
    """Draw a 50-bin histogram of numeric values on a matplotlib axe.

    Parameters
    ----------
    ax: plt.axes.Axes
        axe where to add the plot
    arr: array like
        Array of number values
    color: str (default DEFAULT_COLOR)
        color of the plot
    label: str (default None)
        label of the plot
    alpha: float (default 1.)
        opacity

    Raises
    ------
    TypeError:
        arr is not an array like
    TypeError:
        arr is not a number array
    """
    is_valid_array = utils.is_array_like(arr)
    if not is_valid_array:
        raise TypeError('arr is not an array like')

    dtype_name = utils.find_dtype(arr)
    if dtype_name != 'number':
        raise TypeError('arr is not a number array')

    n_bins = 50
    hist_style = dict(color=color, label=label, alpha=alpha)
    ax.hist(arr, bins=n_bins, **hist_style)


def plot_datetime_var(ax, arr, color=DEFAULT_COLOR, label=None, alpha=1.):
    """Plots a filled line plot of datetime value counts into a matplotlib axe.

    Dates are bucketed at a resolution chosen from the overall time span:
    by year (span > 1500 days), by month (> 100 days), by day (> 5 days),
    otherwise by second. The number of values in each bucket is drawn as a
    line, and the area under that line is filled.

    Parameters
    ----------
    ax: plt.axes.Axes
        axe where to add the plot
    arr: array like
        Array of datetime values (unparseable values are ignored)
    color: str (default DEFAULT_COLOR)
        color of the plot
    label: str (default None)
        label of the plot
    alpha: float (default 1.)
        opacity of the filled area

    Raises
    ------
    TypeError:
        arr is not an array like
    TypeError:
        arr is not a datetime array
    """
    if not utils.is_array_like(arr):
        raise TypeError('arr is not an array like')
    if utils.find_dtype(arr) != 'datetime':
        raise TypeError('arr is not a datetime array')

    # pd.to_datetime returns a DatetimeIndex for lists / ndarrays, which has
    # no ``.dt`` accessor; wrapping in a Series makes every input behave the
    # same. Values coerced to NaT are dropped so they never become bucket
    # keys (they used to turn years into floats such as "2020.0" / "nan").
    dates = pd.Series(pd.to_datetime(arr, errors='coerce')).dropna()
    if dates.empty:
        # Nothing parseable to plot.
        return

    gap = (dates.max() - dates.min()).days
    if gap > 1500:
        fmt = '%Y'
    elif gap > 100:
        fmt = '%Y-%m'
    elif gap > 5:
        fmt = '%Y-%m-%d'
    else:
        # 24-hour clock: the former 12-hour '%r' (AM/PM) keys did not sort
        # chronologically and depend on the platform/locale strftime.
        fmt = '%Y-%m-%d %H:%M:%S'

    # Every format is zero-padded, so sorting the string keys is also a
    # chronological sort.
    counts = dates.dt.strftime(fmt).value_counts().sort_index()
    # Parse the keys back with the same explicit format so a year bucket is
    # placed on January 1st (not on today's month/day).
    x = pd.to_datetime(counts.index, format=fmt).to_pydatetime()
    y = counts.values

    ax.plot(x, y, color=color, label=label)
    ax.fill_between(x, 0, y, color=color, alpha=alpha)


def plot_object_var(ax, arr, top=10, color=DEFAULT_COLOR, label=None,
                    alpha=1.):
    """Plots a bar plot of the most frequent values into a matplotlib axe.

    Parameters
    ----------
    ax: plt.axes.Axes
        axe where to add the plot
    arr: array like
        Array of object values
    top: int (default 10)
        number of most frequent distinct values to plot
    color: str (default DEFAULT_COLOR)
        color of the plot
    label: str (default None)
        label of the plot
    alpha: float (default 1.)
        opacity

    Raises
    ------
    TypeError:
        arr is not an array like
    TypeError:
        arr is not a object array
    """
    if not utils.is_array_like(arr):
        raise TypeError('arr is not an array like')
    if utils.find_dtype(arr) != 'object':
        raise TypeError('arr is not an object array')

    # value_counts is a pandas method: convert any non-pandas array like
    # (list, ndarray, tuple, ...) instead of only lists and ndarrays.
    if not isinstance(arr, (pd.Series, pd.DataFrame)):
        arr = pd.Series(arr)

    counts = arr.value_counts().sort_values(ascending=False).head(top)
    # Matplotlib's categorical axis needs string labels; object columns may
    # hold bools, numbers or a mix of types.
    labels = [str(value) for value in counts.index]
    ax.bar(labels, counts.values, color=color, label=label, alpha=alpha)


def _array_name(values):
    # Display name for titles: Series name or first DataFrame column, else None.
    if isinstance(values, pd.Series):
        return values.name
    if isinstance(values, pd.DataFrame):
        return values.columns[0]
    return None


def plot_variable(arr, legend=None, colors=None, xlog=False, ylog=False,
                  **kwargs):
    """Plots a graph with two parts given an array.

    First part is a plot depending on the array dtype:

    - Histogram if dtype is number (using plot_number_var)
    - Line plot if dtype is datetime (using plot_datetime_var)
    - Bar plot if dtype is object (using plot_object_var)

    Second part is a table with the describe statistics of the array.

    If a legend array is set, the values of arr are split by legend value
    (matched by position) and each group is drawn on the same axe with its
    own color.

    Parameters
    ----------
    arr: array like
        Array of values to plot
    legend: array like (default None)
        Array of legend values (same length as arr); for a DataFrame the
        first column is used
    colors: list (default None)
        Colors cycled through for the legend values, used if legend is set
    xlog: bool (default False)
        Scale xaxis in log scale
    ylog: bool (default False)
        Scale yaxis in log scale
    **kwargs:
        Forwarded to plots.plot_or_figure

    Returns
    -------
    The result of plots.plot_or_figure for the created figure.

    Raises
    ------
    TypeError:
        arr is not an array like
    TypeError:
        legend is not an array like
    ValueError:
        arr and legend have not the same length
    """
    if not utils.is_array_like(arr):
        raise TypeError('arr is not an array like')
    if legend is not None and not utils.is_array_like(legend):
        raise TypeError('legend is not an array like')
    if legend is not None and len(arr) != len(legend):
        raise ValueError('arr and legend have not the same length')

    name = _array_name(arr)
    if isinstance(arr, (list, np.ndarray)):
        arr = pd.Series(arr)

    legend_name = None
    legend_values = None
    if legend is not None:
        legend_name = _array_name(legend)
        if colors is None:
            colors = ['#3498db', '#e67e22', '#2ecc71',
                      '#f1c40f', '#9b59b6', '#e74c3c']
        # A numpy array gives an element-wise (positional) boolean mask
        # below; comparing a plain list to a label yields a single bool and
        # a DataFrame cannot be used as a row mask.
        if isinstance(legend, pd.DataFrame):
            legend_values = legend.iloc[:, 0].to_numpy()
        else:
            legend_values = np.asarray(legend)

    dtype = utils.find_dtype(arr)
    desc = describe(arr)

    # Init figure: two axes, one for the plot (8/12 of the width) and one
    # for the describe stats table (4/12).
    fig = plt.figure(figsize=(15, 5), constrained_layout=False)
    gs = fig.add_gridspec(1, 12)
    ax1 = fig.add_subplot(gs[0, :8])
    ax2 = fig.add_subplot(gs[0, 8:])

    # Format title. Missing or empty names are left out instead of giving
    # "Histogram of " or a TypeError on None / non-string names.
    title = 'Histogram' if dtype == 'number' else 'Plot'
    if name is not None and name != '':
        title += ' of ' + str(name)
    if legend_name is not None and legend_name != '':
        title += ' by ' + str(legend_name)
    ax1.set_title(title, loc='center', fontsize=22)

    # Use plot depending on the dtype
    if dtype == 'number':
        plot_fun = plot_number_var
    elif dtype == 'datetime':
        plot_fun = plot_datetime_var
    else:
        plot_fun = plot_object_var

    if legend is None:
        plot_fun(ax1, arr)
    else:
        # pd.unique keeps the order of first appearance, so each legend value
        # gets the same color on every run (set order of strings is not
        # stable across interpreter runs).
        for i, label in enumerate(pd.unique(legend_values)):
            arr_val = arr[legend_values == label]
            if len(arr_val) == 0:
                continue
            plot_fun(ax1, arr_val, color=colors[i % len(colors)],
                     label=label, alpha=0.5)

    # If log is needed
    if xlog:
        ax1.set_xscale('log')
    if ylog:
        ax1.set_yscale('log')

    # Tick-label formatting is applied after plotting so it targets the
    # final tick labels; tick_params settings also persist if the ticks
    # are re-created.
    if dtype == 'datetime':
        fig.autofmt_xdate()
    elif dtype != 'number':
        ax1.tick_params(axis='x', labelrotation=18)

    # put legend if it's necessary
    if legend is not None:
        ax1.legend(loc=0, frameon=True)

    # Add describe stats table
    desc_formated = utils.format_describe_str(desc)
    plot_table_describe(ax2, desc_formated)

    return plots.plot_or_figure(fig, **kwargs)