Python Treemap Visualization – Plot a Treemap using Python

Greetings! Some links on this site are affiliate links. That means that, if you choose to make a purchase, The Click Reader may earn a small commission at no extra cost to you. We greatly appreciate your support!

Learn to build a Python treemap visualization using Squarify, a library that provides a pure Python implementation of the squarified treemap layout algorithm.

What is a Python Treemap visualization?

A treemap visualizes hierarchical data as a set of nested rectangles of varying sizes. The area of each rectangle is proportional to the share of the whole that it represents.
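To make "proportional" concrete, here is a small pure-Python sketch that computes the area each value would receive if a canvas were shared out in proportion to the values. The 100 x 100 canvas and the function name rectangle_areas are illustrative assumptions, not part of Squarify.

```python
def rectangle_areas(values, width=100.0, height=100.0):
    """Return the area each value receives when a width x height canvas is shared proportionally."""
    total = sum(values)
    canvas_area = width * height
    # Each rectangle's area equals its share of the total, scaled to the canvas area
    return [value * canvas_area / total for value in values]


areas = rectangle_areas([40, 30, 5, 25])
print(areas)       # [4000.0, 3000.0, 500.0, 2500.0]
print(sum(areas))  # 10000.0, the whole canvas
```

A value of 40 out of a total of 100 gets 40% of the canvas, and so on. How those areas are then arranged into rectangles is the job of the layout algorithm.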

Here are two example treemaps laid side by side:

[Figure: Two treemaps, ‘Budget 2010’ and ‘Budget 2020’, laid side by side]

The ‘Budget 2010’ treemap shows that Alcohol received the largest share of the budget in 2010 and Food received the smallest. In contrast, the ‘Budget 2020’ treemap shows that Food, Diapers and baby stuff, and Rent received equal amounts that year, whereas Health Insurance received the least.

Getting started with Python Treemaps

To visualize a treemap in Python, you first need to install the Squarify library. Squarify draws its plots with Matplotlib, so make sure Matplotlib is installed as well (pip install matplotlib).

You can install Squarify by running the following command in your terminal:

pip install squarify

Once the installation is complete, open your Python IDE or the Python shell and import the library with the following line of code:

import squarify

If the import runs without an error, Squarify is installed correctly and you can move on to visualizing a treemap in Python.
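If you would rather check from a script whether Squarify and Matplotlib are installed, here is a minimal sketch using only the standard library (importlib.metadata needs Python 3.8 or later). The helper name package_status is hypothetical.

```python
import importlib.util
from importlib import metadata


def package_status(name):
    """Return the installed version of a package, or None if it cannot be imported."""
    if importlib.util.find_spec(name) is None:
        return None
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        # Importable, but no distribution metadata was found under this name
        return "unknown version"


for package in ("squarify", "matplotlib"):
    status = package_status(package)
    if status is None:
        print(f"{package}: not installed (run: pip install {package})")
    else:
        print(f"{package}: {status}")
```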

How to create a Treemap in Python?

You can create a basic treemap in just a few lines of code by passing the sizes of the rectangles to Squarify’s plot() function, as shown below:

# Importing the Squarify library
import squarify

# Importing Matplotlib
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25])

# Displaying the plot
plt.show()

[Figure: A basic treemap created in Python]
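Behind the scenes, plot() scales the sizes to fill its drawing area, works out where each rectangle goes, and then draws the rectangles with Matplotlib. To show the idea, here is a simplified pure-Python sketch of the squarified layout (Bruls, Huizing and van Wijk): areas are added to a row along the shorter side of the free space for as long as this does not make the row’s most elongated rectangle worse. This is an illustrative re-implementation, not Squarify’s own code; all function names below are hypothetical.

```python
def worst_ratio(row, side):
    """Largest aspect ratio among the rectangles of `row` laid along a side of length `side`."""
    s = sum(row)
    return max(max(side * side * a / (s * s), (s * s) / (side * side * a)) for a in row)


def layout_row(row, x, y, width, height):
    """Place a row of areas along the shorter side; return its rectangles and the free space left."""
    s = sum(row)
    rects = []
    if width >= height:
        # Stack the row vertically in a column on the left, spanning the full height
        col_width = s / height
        top = y
        for a in row:
            rects.append((x, top, col_width, a / col_width))
            top += a / col_width
        return rects, (x + col_width, y, width - col_width, height)
    # Otherwise lay the row out horizontally along the top, spanning the full width
    row_height = s / width
    left = x
    for a in row:
        rects.append((left, y, a / row_height, row_height))
        left += a / row_height
    return rects, (x, y + row_height, width, height - row_height)


def squarified(values, x=0.0, y=0.0, width=100.0, height=100.0):
    """Return (x, y, dx, dy) rectangles with areas proportional to positive values, largest first."""
    total = sum(values)
    areas = sorted((v * width * height / total for v in values), reverse=True)
    rects, row = [], []
    for a in areas:
        side = min(width, height)
        if not row or worst_ratio(row + [a], side) <= worst_ratio(row, side):
            row.append(a)  # adding `a` keeps the row at least as square as before
        else:
            placed, (x, y, width, height) = layout_row(row, x, y, width, height)
            rects.extend(placed)
            row = [a]
    if row:
        placed, _ = layout_row(row, x, y, width, height)
        rects.extend(placed)
    return rects


for rect in squarified([40, 30, 5, 25]):
    print(tuple(round(v, 2) for v in rect))
```

Because the sketch sorts the values first, its rectangles come out in descending order of size rather than in input order. Sorting in descending order is part of the squarified algorithm and tends to produce squarer rectangles.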

If you do not specify any colors, Squarify may assign different colors each time you re-run the code (depending on your Squarify version). To keep the colors fixed, pass a list of colors, one per rectangle, to the color parameter of plot():

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25], color=["Red", "Blue", "Yellow", "Green"]) 

# Displaying the plot
plt.show()

[Figure: Treemap with the colors specified]
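Passing an explicit list works well when you know how many rectangles you have. If the number of rectangles varies, a small helper (hypothetical, not part of Squarify) can repeat a fixed palette in order so that the colors stay the same on every run:

```python
from itertools import cycle, islice


def fixed_colors(n, palette=("Red", "Blue", "Yellow", "Green")):
    """Return n color names by repeating a fixed palette in order, with no randomness."""
    return list(islice(cycle(palette), n))


sizes = [40, 30, 5, 25, 12, 8]
print(fixed_colors(len(sizes)))
# ['Red', 'Blue', 'Yellow', 'Green', 'Red', 'Blue']
```

You would then call squarify.plot(sizes=sizes, color=fixed_colors(len(sizes))).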

You can also label the rectangles by passing a list of labels to the label parameter of plot(), as shown below:

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25], color=["Red", "Blue", "Yellow", "Green"], label=["A", "B", "C", "D"])

# Displaying the plot
plt.show()

[Figure: Treemap with labels specified]

Oh no! The ‘B’ label is hard to read on the blue rectangle. You can lighten the colors by lowering their opacity with the alpha parameter of plot(), which takes a value between 0 (fully transparent) and 1 (fully opaque). Here’s an example:

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25], color=["Red", "Blue", "Yellow", "Green"], label=["A", "B", "C", "D"], alpha=0.7)

# Displaying the plot
plt.show()

[Figure: Treemap with a lower opacity]
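Lowering the opacity is one fix. Another is to pick a label color that contrasts with each rectangle. The sketch below, a hypothetical helper, uses the WCAG relative-luminance formula to choose black or white text for a background given as a hex color (Matplotlib’s "Blue" is #0000FF and "Yellow" is #FFFF00):

```python
def text_color_for(hex_color):
    """Return 'black' or 'white', whichever contrasts more with a background like '#0000FF'."""
    hex_color = hex_color.lstrip("#")
    r, g, b = (int(hex_color[i:i + 2], 16) / 255 for i in range(0, 6, 2))

    def linear(channel):
        # Convert an sRGB channel to linear light (WCAG 2.x relative-luminance formula)
        return channel / 12.92 if channel <= 0.03928 else ((channel + 0.055) / 1.055) ** 2.4

    luminance = 0.2126 * linear(r) + 0.7152 * linear(g) + 0.0722 * linear(b)
    # WCAG contrast ratio against white (luminance 1) and against black (luminance 0)
    contrast_with_white = 1.05 / (luminance + 0.05)
    contrast_with_black = (luminance + 0.05) / 0.05
    return "white" if contrast_with_white >= contrast_with_black else "black"


for background in ["#FF0000", "#0000FF", "#FFFF00", "#008000"]:
    print(background, "->", text_color_for(background))
```

After plotting, you could apply it to each label, for example with for text, color in zip(ax.texts, colors): text.set_color(text_color_for(color)), where ax is the Axes that squarify.plot() returns and colors is your list of hex colors. This assumes the labels were drawn in the same order as the colors.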

Also, the axis values add nothing to a treemap, so let’s hide them. Because Squarify draws its treemaps with Matplotlib, you can do this with Matplotlib’s axis('off'):

# Importing the Squarify library
import squarify

# Importing Matplotlib
import matplotlib.pyplot as plt

# Plotting a Python Treemap
squarify.plot(sizes=[40, 30, 5, 25], color=["Red", "Blue", "Yellow", "Green"], label=["A", "B", "C", "D"], alpha=0.7)

# Removing the axis values
plt.axis('off')

# Displaying the plot
plt.show()

[Figure: Treemap with the axis values removed]

That looks much better! Now, let us learn how to plot a treemap using a real-world dataset.

Plotting a Python Treemap using a real-world dataset

Plotting a treemap from a real-world dataset takes three steps:

  1. Read the dataset into a Pandas DataFrame.
  2. Group the data with groupby() to get aggregated values for specific columns.
  3. Plot the treemap with Squarify.

To load a real-world dataset, we will use the Seaborn library, which can fetch several example datasets (an internet connection is needed the first time each dataset is loaded). If you do not already have Seaborn installed, you can install it by running the following command in your terminal:

pip install seaborn

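Now, let’s start by reading in the Titanic dataset with Seaborn. The code in this section displays results the way a Jupyter notebook does; in a plain Python script, wrap the displayed expressions in print().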

# Importing the Seaborn library in Python
import seaborn as sns

# Importing the Squarify library
import squarify

# Importing Matplotlib
import matplotlib.pyplot as plt

# Loading Seaborn's example Titanic dataset (downloaded on first use, returned as a Pandas DataFrame)
titanic = sns.load_dataset('titanic')

# Displaying the first five rows of the dataset
titanic.head()

[Figure: First five rows of the Titanic dataset]

This is the kind of tabular dataset we often come across in the real world. Now, let’s perform the second step: grouping the data to get aggregated values for specific columns.

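For this tutorial, let’s look at how many people survived the sinking of the Titanic in each passenger class (the pclass column). Since the survived column holds 1 for passengers who survived and 0 for those who did not, summing it per class gives the number of survivors. You can do this with the groupby() method of a Pandas DataFrame: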

# Grouping on the pclass column and finding the sum of survived column
grouped_df = titanic.groupby('pclass')[['survived']].sum()

# Displaying the grouped dataset
grouped_df

[Figure: Grouped Titanic DataFrame]

You can see that 136, 87, and 119 passengers in pclass 1, 2, and 3, respectively, survived the sinking of the Titanic.
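To see exactly what groupby() followed by sum() computes here, the sketch below performs the same aggregation in plain Python. The rows are made up and only shaped like the Titanic data (they are not real records), and sum_by_group is a hypothetical helper.

```python
from collections import defaultdict


def sum_by_group(rows, group_key, value_key):
    """Sum value_key within each distinct group_key, like df.groupby(group_key)[value_key].sum()."""
    totals = defaultdict(int)
    for row in rows:
        totals[row[group_key]] += row[value_key]
    # Pandas sorts the group keys by default, so sort here too
    return dict(sorted(totals.items()))


# Hypothetical rows shaped like the Titanic data (made up, not real records)
rows = [
    {"pclass": 1, "survived": 1},
    {"pclass": 3, "survived": 0},
    {"pclass": 2, "survived": 1},
    {"pclass": 1, "survived": 1},
    {"pclass": 3, "survived": 1},
]

print(sum_by_group(rows, "pclass", "survived"))  # {1: 2, 2: 1, 3: 1}
```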

Finally, let us visualize this information using a Python treemap.

# Plotting a Python Treemap (one rectangle per passenger class)
squarify.plot(sizes=grouped_df['survived'].values, color=["Red", "Blue", "Yellow"], label=grouped_df.index, alpha=0.7)

# Removing the axis values
plt.axis('off')

# Displaying the plot
plt.show()

[Figure: Titanic treemap visualization]

Here, the arguments passed to Squarify’s plot() function are:

  • sizes = The number of survivors in each class, taken from the survived column after the groupby, that is, grouped_df['survived'].values. Selecting the column gives a flat one-dimensional array, whereas grouped_df.values would be a two-dimensional array with a single column.
  • color = Three colors, one for each of the three rectangles (one per passenger class).
  • label = The index of the grouped DataFrame (the pclass values 1, 2, and 3), that is, grouped_df.index.

Since pclass 1 has the most survivors (136), it occupies the largest rectangle in the treemap, while pclass 2, with the fewest survivors (87), occupies the smallest.
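Raw counts are easier to compare when the labels also show each class’s share of the total. Using the counts above (136 + 87 + 119 = 342 survivors), a hypothetical helper can build two-line labels:

```python
def share_labels(groups, counts):
    """Build two-line labels showing each group's count and its share of the total."""
    total = sum(counts)
    return [f"Class {group}\n{count} ({count / total:.1%})" for group, count in zip(groups, counts)]


labels = share_labels([1, 2, 3], [136, 87, 119])
print(labels)
# ['Class 1\n136 (39.8%)', 'Class 2\n87 (25.4%)', 'Class 3\n119 (34.8%)']
```

With the DataFrame from above, you could pass label=share_labels(grouped_df.index, grouped_df['survived']) to squarify.plot(); the "\n" puts the count and percentage on a second line.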

How to visualize multiple Python treemaps using subplots?

Since Squarify draws its treemaps with Matplotlib, you can plot multiple treemaps side by side using the subplots() function from Matplotlib’s pyplot module.

Pass each subplot’s Axes object to the ax parameter of Squarify’s plot() function to choose where each treemap is drawn. Then call axis('off') on both Axes to hide their axis values.

Here’s an example of how you can visualize multiple Python treemaps using subplots,

# Importing the Squarify library
import squarify

# Importing Matplotlib
import matplotlib.pyplot as plt

# Creating 2 Matplotlib subplots
fig, ax = plt.subplots(nrows=1, ncols=2)

# Plotting the first treemap on the left subplot
squarify.plot(sizes=[10, 20, 30], color=["Red", "Blue", "Yellow"], label=["A", "B", "C"], alpha=0.7, ax=ax[0])

# Plotting the second treemap on the right subplot
squarify.plot(sizes=[30, 20, 10], color=["Red", "Blue", "Yellow"], label=["X", "Y", "Z"], alpha=0.7, ax=ax[1])

# Removing the axis values from both subplots
ax[0].axis('off')
ax[1].axis('off')

# Displaying the plot
plt.show()

[Figure: Two Python treemaps plotted using subplots]

In Conclusion

You’ve now learned how to build a Python treemap from scratch, both from a list of sizes and from a real-world dataset. Because Squarify draws its treemaps with Matplotlib, you can customize them further with Matplotlib, for example by adding titles or changing the figure size.

Do you have any questions? Feel free to ask them in the comments below.

