d2l.plot and has_one_axis functions

Table of contents

    • Original URL
    • has_one_axis function
    • d2l.plot function
      • demos

Original URL

  • 2.4. Calculus – Dive into Deep Learning 1.0.0-beta0 documentation (d2l.ai)

has_one_axis function

  • In the original text, has_one_axis() is defined inside plot(); it is extracted here as a standalone function to make it easier to analyze

  • # Return True if X has exactly one axis
    def has_one_axis(X):
        # hasattr() makes the check robust: first test whether X has an ndim attribute
        # (e.g. an ndarray or tensor) and, only if it does, whether that value is 1
        # Otherwise, test whether X is a flat list
        """
        isinstance(X, list) and not hasattr(X[0], "__len__")
        This conditional expression checks whether X is a list whose first element is not a sequence (i.e. it is a single value).
        If X has one axis, it holds the data of a single curve, or it is an independent-variable vector that several curves share as input.
    
        The expression consists of two parts joined by the and operator:
    
        isinstance(X, list): checks whether X is a list. Returns True if so, otherwise False.
        not hasattr(X[0], "__len__"): checks whether the first element of X is a sequence (i.e. whether it has a __len__ attribute). Because of the not operator, this part is True if the first element is not a sequence and False otherwise.
        So the whole expression is True exactly when X is a list whose first element is a single value, i.e. a flat list; otherwise X is a multidimensional list or an array.
    
        Consider the following two variables:
        X = [1, 2, 3, 4, 5]
        Y = [[1, 2], [3, 4], [5, 6]]
        X is a list of single values, so it meets both conditions: X is a list, and its first element (1) is not a sequence. The expression evaluates to True.
        Y is a two-dimensional list, so it fails the second condition: its first element ([1, 2]) is a sequence. The expression evaluates to False.
        The plotting function uses this check to bring its inputs into a uniform format.
    
        Alternatively, once X is known to be a list, it can be converted to an ndarray and tested with np.array(X).ndim == 1.
    
        """
        ndarray_1dim = hasattr(X, "ndim") and X.ndim == 1
        list_1dim = isinstance(X, list) and (X == [] or not hasattr(X[0], "__len__"))
        # If X is a list, X == [] is checked first so that X[0] cannot raise an IndexError
        # An empty list counts as one axis; otherwise X has at least 1 element (the elements of X
        # are assumed to share one type), and X has one axis if its elements are not sequences
        # The parentheses are required: without them, `or` would evaluate X[0] even when X is not a list
        # print(ndarray_1dim, list_1dim)
        return ndarray_1dim or list_1dim
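A quick check of the corrected function on the inputs discussed in the docstring. The function is repeated so the snippet runs on its own; has_one_axis_unparenthesized (the version without the parentheses) and has_one_axis_via_numpy (the np.array(X).ndim == 1 alternative) are hypothetical names used only for this comparison.

```python
import numpy as np


def has_one_axis(X):
    """Corrected version: True for a 1-D array/tensor or a flat (possibly empty) list."""
    ndarray_1dim = hasattr(X, "ndim") and X.ndim == 1
    list_1dim = isinstance(X, list) and (X == [] or not hasattr(X[0], "__len__"))
    return ndarray_1dim or list_1dim


def has_one_axis_unparenthesized(X):
    """Without parentheses: parsed as (isinstance(...) and X == []) or not hasattr(X[0], ...)."""
    ndarray_1dim = hasattr(X, "ndim") and X.ndim == 1
    list_1dim = isinstance(X, list) and X == [] or not hasattr(X[0], "__len__")
    return ndarray_1dim or list_1dim


def has_one_axis_via_numpy(X):
    """The alternative from the docstring: convert to an ndarray and test ndim."""
    return np.array(X).ndim == 1


cases = {
    "flat list": [1, 2, 3, 4, 5],
    "nested list": [[1, 2], [3, 4], [5, 6]],
    "empty list": [],
    "1-D ndarray": np.arange(5),
    "2-D ndarray": np.ones((2, 3)),
    "empty 1-D ndarray": np.array([]),
    "tuple": (1, 2, 3),
}
for name, X in cases.items():
    # The two checks agree on every case except the tuple, which only the ndim test accepts
    print(name, has_one_axis(X), has_one_axis_via_numpy(X))

# The unparenthesized version indexes X[0] even for an ndarray, so an empty array fails
try:
    has_one_axis_unparenthesized(np.array([]))
except IndexError as err:
    print("IndexError:", err)
```

The ndim-based alternative is shorter, but it also accepts tuples and any other array-like, while has_one_axis only accepts lists and objects that expose ndim.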
    
    

d2l.plot function

  • This function calls has_one_axis() defined above
  • zip() | Built-in Functions – Python documentation
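Two Python behaviours drive the shape handling in plot(): zip() stops at its shortest argument, and * repeats a list but multiplies an ndarray element-wise. A minimal illustration with small made-up lists:

```python
import numpy as np

Y = [[1, 2], [3, 4], [5, 6]]   # three curves
X = [[0, 1]]                   # one row of x values shared by all curves
fmts = ('-', 'm--')            # only two styles

# zip() stops at the shortest input: only one triple, because len(X) == 1
print(list(zip(X, Y, fmts)))
# → [([0, 1], [1, 2], '-')]

# Repeating the list of rows gives len(X) >= len(Y); repeating fmts does the same for styles
X_rep = X * len(Y)             # [[0, 1], [0, 1], [0, 1]]
fmts_rep = fmts * len(Y)       # ('-', 'm--', '-', 'm--', '-', 'm--')
print(len(list(zip(X_rep, Y, fmts_rep))))
# → 3

# * repeats a list but multiplies an ndarray element-wise
print([1, 2] * 3, np.array([1, 2]) * 3)
# → [1, 2, 1, 2, 1, 2] [3 6]
```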
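#@save
# Assumes d2l (for d2l.plt), has_one_axis() above, and the set_figsize() and
# set_axes() helpers defined earlier in Section 2.4 are in scope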
def plot(X, Y=None, xlabel=None, ylabel=None, legend=[], xlim=None,
         ylim=None, xscale='linear', yscale='linear',
         fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):
    """Plot data points.

    This function draws one or more groups of data points and lets you set
    various parameters, such as the axis labels, the legend and the axis ranges.

    Parameters:

    X: a list or array of x-axis (abscissa) values.
    Y: a list or array of y-axis values; each row is the vector of function values
       that one function computes from X (the independent-variable array). If Y is
       not provided, X is used as the y-axis values by default.
    xlabel: the label of the x-axis.
    ylabel: the label of the y-axis.
    legend: a list of strings used as legend labels. The default is [].
    xlim: a tuple with the minimum and maximum of the x-axis range.
    ylim: a tuple with the minimum and maximum of the y-axis range.
    xscale: the scale type of the x-axis. The default is 'linear'.
    yscale: the scale type of the y-axis. The default is 'linear'.
    fmts: a tuple of line format strings. The default ('-', 'm--', 'g-.', 'r:')
       gives up to 4 curves mutually distinct styles.
    figsize: a tuple with the width and height of the figure. The default is (3.5, 2.5).
    axes: a matplotlib.axes.Axes object to draw on. If not provided, the current
       Axes (d2l.plt.gca()) is used.

    The implementation mainly goes through the following steps:

    Set the figure size.
    Use the given Axes object, or get the current one.
    Check the format of the input data and unify it into lists of rows.
    Clear the Axes and plot each group of data points.
    Set the labels, ranges, scale types and legend of the axes.

    Multiple groups of data can thus be plotted in one call, and the style and
    layout of the figure can be adjusted through the parameters.
    """
  

    set_figsize(figsize)
    axes = axes if axes else d2l.plt.gca()
    #gca: Get the current Axes.


    if has_one_axis(X):
        # Process a uniaxial vector into a two-dimensional matrix (two-axis tensor)
        X = [X]
    #So far, it can be guaranteed that X has at least 2 axes
    if Y is None:
        # print(len(X))
        # len(X) is the size of the outermost axis of X (if X was a one-axis vector,
        # the previous if wrapped it into a matrix, so the outermost axis has size 1)
        # The line below turns X into n = len(X) empty rows, one per curve; an empty x
        # later makes axes.plot(y, fmt) use the indices of y as the x values
        # For example, [[]] * 3 == [[], [], []]
        X, Y = [[]] * len(X), X
        # print(X, '\n', Y)
        # If Y is None, this ensures that X and Y both have 2 axes and the same number of rows
    elif has_one_axis(Y):
        # Y was given: if it has one axis, wrap it into two axes as well
        Y = [Y]
    # At this point X and Y both have 2 axes (provided the inputs had at most 2 axes)
    # When drawing n = len(Y) > 1 curves, the curves may share one X, in which case
    # X holds only r = 1 row (len(X) = r)
    # If r > 1 (and r != n), no error is raised either: the rows of X are reused cyclically
    # To keep the loop below uniform, X is augmented with X = X * len(Y) whenever len(X) != len(Y)
    if len(X) != len(Y):
        # Note: this assumes X is a list of rows here, with shape (r, m), where m is the length
        # of each row, e.g. X = [[1, 2, 3]] gives m = 3 (m does not matter here); if X were a
        # 2-D ndarray, X * len(Y) would multiply its values instead of repeating its rows
        # Multiplying a list by a constant k gives a new list with k times as many elements,
        # e.g. [[1, 2, 3]] * 3 == [[1, 2, 3], [1, 2, 3], [1, 2, 3]] and
        # [np.array([1, 2])] * 3 gives [array([1, 2]), array([1, 2]), array([1, 2])]
        # By contrast, [1, 2] * 3 and np.array([1, 2]) * 3 give ([1, 2, 1, 2, 1, 2], array([3, 6]))
        # The goal is not len(X) == len(Y) but len(X) >= len(Y), so that no curve in Y
        # is dropped by zip() for lack of a matching row in X
        X = X * len(Y)
    # With the default fmts, Y should not hold more than 4 curves, because only 4 styles are given:
    # zip(X, Y, fmts) stops at the shortest of X, Y and fmts, so extra curves would be silently dropped
    # To allow more than 4 curves, repeat the styles with fmts * len(Y):
    fmts = fmts * len(Y)
    # The input data now share one format (a list of rows), so the loop can take
    # one group of data at a time and draw multiple curves
    # Clear the Axes
    axes.cla()
    # draw each curve
    for x, y, fmt in zip(X, Y, fmts):
        # print(x,y,fmt)
        if len(x):
            axes.plot(x, y, fmt)
        else:
            axes.plot(y, fmt)
    set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
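The shape-unification steps of plot() can be checked without matplotlib. The sketch below is a hypothetical helper, normalize_plot_inputs (not part of d2l), that repeats those steps under the assumption that the corrected has_one_axis is used, and returns the (x, y, fmt) triples that plot() would hand to axes.plot():

```python
import numpy as np


def has_one_axis(X):
    # Corrected check from above: 1-D array/tensor, or a flat (possibly empty) list
    return (hasattr(X, "ndim") and X.ndim == 1) or (
        isinstance(X, list) and (X == [] or not hasattr(X[0], "__len__")))


def normalize_plot_inputs(X, Y=None, fmts=('-', 'm--', 'g-.', 'r:')):
    """Hypothetical helper: plot()'s input handling, without any drawing.

    Returns the (x, y, fmt) triples that plot() would pass to axes.plot();
    an empty x means y is plotted against its indices.
    """
    if has_one_axis(X):
        X = [X]                      # one row -> a list holding that row
    if Y is None:
        X, Y = [[]] * len(X), X      # the rows of X become the y values
    elif has_one_axis(Y):
        Y = [Y]
    if len(X) != len(Y):
        X = X * len(Y)               # repeat the rows of X so len(X) >= len(Y)
    fmts = fmts * len(Y)             # enough styles for every curve
    return list(zip(X, Y, fmts))


# Two curves sharing one x vector
x = np.linspace(0, 1, 10)
triples = normalize_plot_inputs(x, [np.sin(x), np.cos(x)])
print(len(triples), [fmt for _, _, fmt in triples])  # 2 ['-', 'm--']

# Y is None: each row of X is plotted against its indices (empty x rows)
triples = normalize_plot_inputs([[1, 4, 9], [2, 3, 5]])
print([(xs, ys) for xs, ys, _ in triples])  # [([], [1, 4, 9]), ([], [2, 3, 5])]
```

Keeping the normalization in a pure function like this makes the edge cases (shared x, Y=None, more than four curves) testable without opening a figure.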

demos

  • import numpy as np
    
    # Generate 10 evenly spaced values in [0, 1] as the x-axis coordinates
    x = np.linspace(0, 1, 10)
    
    # Generate 5 sets of standard normal random numbers as y-axis values
    y1 = np.random.randn(10)
    y2 = np.random.randn(10)
    y3 = np.random.randn(10)
    y4 = np.random.randn(10)
    y5 = np.random.randn(10)
    # Legend labels y1 ... y5, matching the variable names
    legend = ["y" + str(i) for i in range(1, 6)]
    
    
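  • # Call the plot function to draw five curves that share x
    # (five curves exceed the four default fmts, so this relies on the fmts = fmts * len(Y) line in plot)
    plot(x, [y1, y2, y3, y4, y5], xlabel='x', ylabel='y', legend=legend)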
    
  • # Use the x vector as input (the two curves share the same independent variable)
    plot(X=x, Y=Y, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])
    
  • # Since Y is None, X is used as Y, and X is replaced by a list of empty rows (size 0)
    # Each curve is then plotted against its indices, i.e. range(len(y)) for each row y
    plot(X=X, Y=None, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])
    
      • E = np.array([[], []])
        E, E.shape, E.size
        # (array([], shape=(2, 0), dtype=float64), (2, 0), 0)
        
  • # X is a matrix with two rows; plot uses each row as the independent-variable array
    # of one curve, so the two curves are drawn over different intervals
    plot(X=X, Y=Y, xlabel='x', ylabel='f(x)', legend=['f(x)', 'Tangent line (x=1)'])
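The last three demos use Y and a two-row X that the post does not define. A possible setup, assuming the example function f(x) = 3x^2 - 4x from Section 2.4 and its tangent line at x = 1; the second row of X (shifted by 3) is a hypothetical choice that makes the two curves occupy different intervals, and x is redefined here with np.arange(0, 3, 0.1).

```python
import numpy as np


def f(x):
    # Assumed example function from Section 2.4: f(x) = 3x^2 - 4x
    return 3 * x ** 2 - 4 * x


x = np.arange(0, 3, 0.1)
# f'(x) = 6x - 4, so at x = 1 the slope is 2 and the point is (1, f(1)) = (1, -1);
# the tangent line is therefore y = 2(x - 1) - 1 = 2x - 3
Y = [f(x), 2 * x - 3]
# Hypothetical two-row X: the second row is shifted, so the curves occupy different intervals
X = np.stack([x, x + 3])
print(X.shape, len(Y))
```

With these names defined (and plot() plus its helpers in scope), the three plot(...) calls above receive inputs of the shapes they expect.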