3D Scatter Plot

Using the iris dataset from scikit-learn, create a 3D scatter plot showcasing the relationship between sepal length, sepal width, and petal length. Differentiate each species with a unique color.
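The three measurements the task asks for are the first three columns of iris.data. The sketch below selects them by name, so the mapping from columns to measurements is explicit rather than implied by the indices 0, 1, 2. It assumes scikit-learn's standard iris feature names, and the helper name iris_3d_features is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris

def iris_3d_features():
    """Hypothetical helper: return the three plotted columns, integer labels, and species names."""
    iris = load_iris()
    # Look the columns up by name instead of hard-coding 0, 1, 2
    wanted = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)"]
    idx = [iris.feature_names.index(name) for name in wanted]
    X = iris.data[:, idx]  # shape (150, 3): one row per flower
    names = [str(n) for n in iris.target_names]  # species names as plain strings
    return X, iris.target, names

X, y, names = iris_3d_features()
print(X.shape)
for label, name in enumerate(names):
    # Number of samples belonging to each species
    print(name, int(np.sum(y == label)))
```

Each species contributes the same number of rows, so no class dominates the plot.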


Create the 3D axis by passing projection='3d' to fig.add_subplot(). This returns an Axes3D object, the class defined in mpl_toolkits.mplot3d.
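A quick way to confirm what projection='3d' produces is to check the returned axis type. Since Matplotlib 3.2, the '3d' projection is registered automatically, so the explicit Axes3D import is mostly needed on older versions or for isinstance checks like this one. The sketch uses the non-interactive Agg backend, which is an assumption chosen so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; lets the check run without a display
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")  # the '3d' projection yields an Axes3D
print(isinstance(ax, Axes3D))  # prints True
plt.close(fig)
```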

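import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on Matplotlib < 3.2

def scatter_3d_iris():
    iris = load_iris()
    colors = ['red', 'green', 'blue']  # one color per species, in target order 0, 1, 2

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')  # creates an Axes3D
    # Plot each species separately so it gets its own color and legend entry
    for target, color in enumerate(colors):
        subset = iris.data[iris.target == target]  # rows belonging to this species
        # Columns 0, 1, 2 are sepal length, sepal width and petal length (cm)
        ax.scatter(subset[:, 0], subset[:, 1], subset[:, 2],
                   color=color, label=iris.target_names[target], s=50, alpha=0.6)
    ax.set_xlabel("Sepal Length (cm)")
    ax.set_ylabel("Sepal Width (cm)")
    ax.set_zlabel("Petal Length (cm)")
    ax.set_title("3D Scatter Plot for Iris Dataset")
    ax.legend()
    plt.show()

# Example usage
if __name__ == "__main__":
    scatter_3d_iris()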
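plt.show() needs an interactive display, which is not available in scripts, CI jobs, or some notebooks. The variant below builds the same figure but returns it, so the caller can save or inspect it. The function name build_scatter_3d_iris, the Agg backend, and the in-memory PNG buffer are assumptions for illustration, not part of the original exercise.

```python
import io

import matplotlib
matplotlib.use("Agg")  # render off-screen instead of opening a window
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

def build_scatter_3d_iris():
    """Hypothetical variant of scatter_3d_iris() that returns the figure instead of showing it."""
    iris = load_iris()
    colors = ["red", "green", "blue"]  # one color per species
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    for target, color in enumerate(colors):
        subset = iris.data[iris.target == target]  # rows for this species
        ax.scatter(subset[:, 0], subset[:, 1], subset[:, 2],
                   color=color, label=str(iris.target_names[target]), s=50, alpha=0.6)
    ax.set_xlabel("Sepal Length (cm)")
    ax.set_ylabel("Sepal Width (cm)")
    ax.set_zlabel("Petal Length (cm)")
    ax.set_title("3D Scatter Plot for Iris Dataset")
    ax.legend()
    return fig, ax

fig, ax = build_scatter_3d_iris()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100)  # PNG bytes in memory; pass a filename to save to disk instead
png_bytes = buf.getvalue()
plt.close(fig)
```

Returning fig and ax keeps drawing separate from display. This makes the function easier to test and to reuse in code that writes images to disk.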

 

© Let’s Data Science
