Learn to build a Python treemap visualization with Squarify, a library that provides a pure Python implementation of the squarified treemap layout algorithm.
What is a Python Treemap visualization?
A Python treemap visualization plots hierarchical data as nested rectangles of varying sizes. The area of each rectangle is proportional to the share of the total that its value represents.
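The following minimal pure-Python sketch makes the proportionality rule concrete. It computes how much area each value should receive on a 100 x 100 canvas, which is the default canvas size used by squarify.plot(). The helper name rectangle_areas is hypothetical. Squarify performs this scaling internally before it lays out the rectangles.

```python
def rectangle_areas(sizes, width=100.0, height=100.0):
    """Hypothetical helper: area each value should occupy on a width x height canvas."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("sizes must sum to a positive number")
    canvas = width * height
    # Each value gets the same fraction of the canvas as it has of the total.
    return [size * canvas / total for size in sizes]


print(rectangle_areas([40, 30, 5, 25]))  # [4000.0, 3000.0, 500.0, 2500.0]
```

The areas always add up to the full canvas, so a value that makes up 40% of the total covers 40% of the treemap.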
Here is an example of two Python treemaps laid side by side:
The ‘Budget 2010’ treemap shows that Alcohol received the largest share of the budget in 2010 and Food received the smallest. The ‘Budget 2020’ treemap, on the other hand, shows that Food, Diapers and baby stuff, and Rent received equal amounts that year, whereas Health Insurance received the least.
Getting started with Python Treemaps
To start visualizing a treemap in Python, you first need to install the Squarify library.
You can install it by running the following command in your command line/terminal:
pip install squarify
Once the installation is complete, open your Python IDE or Python shell and import the library with the following line of code:
import squarify
If you do not encounter any error when importing the library, you’ve successfully installed and imported the Squarify library in Python. Now, you can move on to visualizing a treemap in Python.
How to create a Treemap in Python?
You can create a basic Python treemap in just a few lines of code by passing the sizes of the rectangles to the plot() function from the Squarify library:
# Importing the Squarify library
import squarify
# Importing Matplotlib
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25])
# Displaying the plot
plt.show()
Note that each time you re-run squarify.plot() without specifying colors, Squarify picks the rectangle colors at random, so your plot will look different every time. To keep the colors fixed, pass a list of colors to the color parameter of plot():
import squarify
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25], color=["Red", "Blue", "Yellow", "Green"])
# Displaying the plot
plt.show()
You can also assign a label to each rectangle by passing a list of labels to the label parameter of plot():
import squarify
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25],
              color=["Red", "Blue", "Yellow", "Green"],
              label=["A", "B", "C", "D"])
# Displaying the plot
plt.show()
Oh no! The ‘B’ label is hard to read against the blue rectangle. You can make the colors of a Python treemap more transparent with the alpha parameter of Squarify's plot(). It takes a value from 0 (fully transparent) to 1 (fully opaque). Here's an example:
import squarify
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25],
              color=["Red", "Blue", "Yellow", "Green"],
              label=["A", "B", "C", "D"],
              alpha=0.7)
# Displaying the plot
plt.show()
Also, the axis values add nothing useful to a treemap. Since Squarify draws its treemaps on top of Matplotlib, you can hide them with Matplotlib's plt.axis('off'):
# Importing the Squarify library
import squarify
# Importing Matplotlib
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25],
              color=["Red", "Blue", "Yellow", "Green"],
              label=["A", "B", "C", "D"],
              alpha=0.7)
# Removing the axis values
plt.axis('off')
# Displaying the plot
plt.show()
That looks much better! Now, let us learn how to plot a treemap using a real-world dataset.
Plotting a Python Treemap using a real-world dataset
To plot a Python treemap from a real-world dataset, follow three simple steps:
- Read the dataset into Python with Pandas.
- Group the data to get aggregated values for specific columns.
- Plot the treemap.
To load a real-world dataset, we will use the Seaborn library, whose load_dataset() function provides several example datasets. If you do not already have the library installed, you can install it by running the following command in your command line/terminal:
pip install seaborn
Now, let us start by reading in the Titanic dataset using the Seaborn library:
# Importing the Seaborn library in Python
import seaborn as sns
# Importing Matplotlib
import matplotlib.pyplot as plt

# Loading the built-in Titanic dataset (downloaded on first use)
titanic = sns.load_dataset('titanic')
# Displaying the first five rows of the dataset
titanic.head()
This looks like a typical real-world dataset. Now, let's perform the second step: grouping the data to get aggregated values for specific columns.
For this tutorial, let us look at how many people survived the sinking of the Titanic in each passenger class (the pclass column). You can do this with the groupby() method of a Pandas DataFrame by summing the survived column, which is 1 for survivors and 0 otherwise.
# Grouping on the pclass column and finding the sum of the survived column
grouped_df = titanic.groupby('pclass')[['survived']].sum()
# Displaying the grouped dataset
grouped_df
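The double brackets in [['survived']] select a list of columns, so the result is a DataFrame whose .values is two-dimensional. Single brackets return a one-dimensional Series, which is the flat sequence of sizes that a treemap needs. The sketch below shows the difference. It uses a tiny made-up DataFrame instead of the real Titanic data, so it runs offline.

```python
import pandas as pd

# Toy stand-in for the Titanic data (made-up rows, for illustration only).
toy = pd.DataFrame({"pclass": [1, 1, 2, 3, 3, 3],
                    "survived": [1, 0, 1, 1, 1, 0]})

# Double brackets: a DataFrame with one column, i.e. a 2-D result.
as_frame = toy.groupby("pclass")[["survived"]].sum()
print(type(as_frame).__name__, as_frame.values.shape)  # DataFrame (3, 1)

# Single brackets: a Series, i.e. a 1-D result.
as_series = toy.groupby("pclass")["survived"].sum()
print(type(as_series).__name__, as_series.values.shape)  # Series (3,)
```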
You can see that 136, 87, and 119 passengers in pclass 1, 2, and 3, respectively, survived the sinking of the Titanic.
Finally, let us visualize this information using a Python treemap.
# Importing the Squarify library
import squarify
# Importing Matplotlib
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=grouped_df['survived'], color=["Red", "Blue", "Yellow"], label=grouped_df.index, alpha=0.7)
# Removing the axis values
plt.axis('off')
# Displaying the plot
plt.show()
Here, the parameters passed to Squarify's plot() function are as follows:
- sizes = The values of the survived column after the groupby is performed, that is, grouped_df['survived']. Selecting the column gives a flat, one-dimensional sequence, whereas grouped_df.values would be a two-dimensional array.
- color = Three colors, since there are three values to plot in the survived column.
- label = The index of the grouped DataFrame (the pclass values), that is, grouped_df.index.
Since pclass 1 has the most surviving passengers (136), it occupies the largest area in the treemap. Pclass 2 occupies the smallest area, since it has the fewest survivors (87).
How to visualize multiple Python treemaps using subplots?
Since Squarify builds its treemaps with Matplotlib, you can plot multiple treemaps side by side with the subplots() function from Matplotlib's pyplot module.
Pass each subplot's axis to the ax parameter of Squarify's plot() to choose where each treemap is drawn. Also, call axis('off') on each axis to hide the axis values in every subplot.
Here’s an example of how you can visualize multiple Python treemaps using subplots,
# Importing the Squarify library
import squarify
# Importing Matplotlib
import matplotlib.pyplot as plt

# Creating 2 Matplotlib subplots
fig, ax = plt.subplots(nrows=1, ncols=2)
# Plotting a Python Treemap on the first subplot
squarify.plot(sizes=[10, 20, 30], color=["Red", "Blue", "Yellow"], label=["A", "B", "C"], alpha=0.7, ax=ax[0])
# Plotting a Python Treemap on the second subplot
squarify.plot(sizes=[30, 20, 10], color=["Red", "Blue", "Yellow"], label=["X", "Y", "Z"], alpha=0.7, ax=ax[1])
# Removing the axis values
ax[0].axis('off')
ax[1].axis('off')
# Displaying the plot
plt.show()
In Conclusion
You’ve successfully learned how to build a Python treemap from scratch, both from a list of sizes and from a real-world dataset. Remember that because Squarify draws on top of Matplotlib, you can customize its treemaps further with the Matplotlib library.
Do you have any questions? Feel free to leave a comment below.