Python has consistently ranked among the most popular programming languages. According to the TIOBE Index, Python was the most popular language in 2021. In this article, we’ll explore the filter() function, one of Python's most useful built-in tools for working with iterables.
filter() is a built-in Python function, meaning it does not require importing additional libraries.
Syntax
The function takes two arguments: a function and an iterable object.
filter(function, iterable)
function is a function that takes a single argument. It is used to test each value.
iterable is an object that can be iterated over, such as a list, tuple, or dictionary. It can also be a generator or iterator object. The filter() function accepts only one iterable object.
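As a rough illustration of the kinds of iterables mentioned above, the sketch below passes a tuple, a dictionary, and a generator to filter(). The names is_positive and stock are made up for this example and are not part of the article's own code.

```python
def is_positive(n):
    """Hypothetical predicate: keep numbers greater than zero."""
    return n > 0

# A tuple works the same way as a list.
from_tuple = list(filter(is_positive, (-2, 0, 3, 5)))

# Iterating over a dictionary yields its keys, so the predicate receives keys.
stock = {"apples": 3, "pears": 0, "plums": 7}
in_stock = list(filter(lambda name: stock[name] > 0, stock))

# A generator expression is consumed one element at a time.
squares = (n * n for n in range(6))
even_squares = list(filter(lambda n: n % 2 == 0, squares))

print(from_tuple)    # [3, 5]
print(in_stock)      # ['apples', 'plums']
print(even_squares)  # [0, 4, 16]
```

Note that with a dictionary, filter() sees only the keys; to filter on values, the predicate has to look them up itself, as the lambda does here.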
The filter() function passes each element of the provided iterable to the given function, which returns a value (True, False, or something else, such as a number or string). filter() evaluates the returned value, and if it is "truthy" (not necessarily equal to True, but considered true), it includes the element in the resulting iterator. If the value is not truthy, the element is excluded. The result is an iterator that yields only the elements for which the function returned a truthy value.
To get the elements for which the function returns a falsy value, use the itertools.filterfalse() function.
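To make the "truthy" rule concrete, here is a small sketch that uses the built-in len() as the filtering function. len() returns an integer rather than True or False, so this relies on 0 being falsy and any other length being truthy. The words list is invented for the example.

```python
from itertools import filterfalse

words = ["", "tea", "a", "", "milk"]

# len() returns 0 for empty strings (falsy) and a positive number otherwise (truthy).
non_empty = list(filter(len, words))

# filterfalse() keeps exactly the elements that filter() would drop.
empty = list(filterfalse(len, words))

print(non_empty)  # ['tea', 'a', 'milk']
print(empty)      # ['', '']
```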
The filter() function can be faster than an equivalent for loop, particularly when the filtering function is itself a built-in, although the actual difference depends on the function and the data. Another advantage is that filter() returns an iterator, which produces elements on demand and is therefore more memory-efficient than building a full list. This behavior was introduced in Python 3. In Python 2, the filter() function returns a list.
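The iterator behavior in Python 3 can be observed directly. The following sketch, using a made-up list of numbers, shows that elements are produced only when requested and that the iterator can be consumed only once.

```python
numbers = [1, 2, 3, 4, 5, 6]

evens = filter(lambda n: n % 2 == 0, numbers)
print(type(evens).__name__)  # filter

first = next(evens)  # the lambda runs only now, up to the first match
rest = list(evens)   # the remaining matches
again = list(evens)  # the iterator is exhausted, so nothing is left

print(first, rest, again)  # 2 [4, 6] []
```

If the filtered values are needed more than once, it is usually simplest to store them in a list first.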
Now that we’ve covered the basics, let's look at how filter() works through various examples.
One of the simplest examples is filtering out even numbers so that only the odd ones remain.
numbers = [1, 2, 3, 4, 5, 7, 10, 11]

def filter_num(num):
    # Keep the number if dividing by 2 leaves a remainder (the number is odd)
    if (num % 2) != 0:
        return True
    else:
        return False

out_filter = filter(filter_num, numbers)
print("Filtered list:", list(out_filter))
In this case, we pass a custom function (filter_num) and a list of numbers (numbers) to filter(). The result will be:
Filtered list: [1, 3, 5, 7, 11]
Our custom function checks whether each number is odd. If there is a non-zero remainder when dividing by 2, the function returns True, meaning the element is included in the resulting iterator. Since filter() returns an object of type <class 'filter'>, we need to convert the output to a list to see the result. This example can also be implemented using a lambda function:
filter(lambda n: n % 2 != 0, numbers)
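For comparison, the same filtering can be written as a list comprehension. This is only an alternative sketch, not a replacement for the approach in the article; both forms produce the same elements here.

```python
numbers = [1, 2, 3, 4, 5, 7, 10, 11]

# filter() with a lambda returns an iterator, so we wrap it in list().
with_filter = list(filter(lambda n: n % 2 != 0, numbers))

# A list comprehension builds the list directly.
with_comprehension = [n for n in numbers if n % 2 != 0]

print(with_filter == with_comprehension)  # True
print(with_comprehension)                 # [1, 3, 5, 7, 11]
```

Which form reads better is largely a matter of taste; filter() tends to be most compact when an existing named function can be passed in without a lambda.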
filter() can also be used to find the elements that two lists have in common. Input data:
arr1 = ['1', '2', '3', '4', 5, 6, 7]
arr2 = [1, '2', 3, '4', '5', '6', 7]
We write a function to find the intersection:
def intersection(arr1, arr2):
    # Keep each element of arr2 that also occurs in arr1
    out = list(filter(lambda it: it in arr1, arr2))
    return out

The function takes two lists as input. Using a lambda function, it keeps every element of arr2 that also occurs in arr1, which gives the common elements.
Calling the function and displaying the result:
out = intersection(arr1, arr2)
print("Filtered list:", out)
The result:
Filtered list: ['2', '4', 7]
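For longer lists, checking membership against a list on every element can become slow. One possible variation, assuming all elements are hashable, is to build a set first. The function name intersection_with_set is hypothetical and not part of the original example.

```python
arr1 = ['1', '2', '3', '4', 5, 6, 7]
arr2 = [1, '2', 3, '4', '5', '6', 7]

def intersection_with_set(arr1, arr2):
    """Hypothetical variant: same result, but membership tests use a set."""
    lookup = set(arr1)  # assumes every element of arr1 is hashable
    return list(filter(lambda it: it in lookup, arr2))

common = intersection_with_set(arr1, arr2)
print("Filtered list:", common)  # Filtered list: ['2', '4', 7]
```

The order of the result still follows arr2, because filter() walks through arr2 and the set is used only for lookups.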
Lambda functions are also convenient for string checks with filter(). For example, let’s create a palindrome detector:
words = ["cat", "rewire", "level", "book", "stats", "list"]
palindromes = list(filter(lambda word: word == word[::-1], words))
print("Palindromes:", palindromes)
The result:
Palindromes: ['level', 'stats']
The lambda function checks if a word is the same when written in reverse. If it is, the function returns True.
filter() can also help remove outliers from a sample. We import a library for statistical computations and set up a roughly normally distributed sample with a few outliers:
import statistics as st
sample = [10, 8, 10, 8, 2, 7, 9, 3, 34, 9, 5, 9, 25]
We calculate the mean:
mean = st.mean(sample)
Mean: 10.69
In normally distributed samples, outliers are often defined as values that deviate from the mean by more than two standard deviations. We calculate the standard deviation and the lower and upper bounds:
stdev = st.stdev(sample)
low = mean - 2*stdev
high = mean + 2*stdev
Then we filter the sample:
clean_sample = list(filter(lambda x: low <= x <= high, sample))
Result:
[10, 8, 10, 8, 2, 7, 9, 3, 9, 5, 9, 25]
Clearly, the value 34 was an outlier. Now, the new mean is 8.75.
If we perform another iteration of this method, the value 25 will also be filtered out, leaving us with:
Sample without outliers: [10, 8, 10, 8, 2, 7, 9, 3, 9, 5, 9]
The new mean is 7.273, which is noticeably lower than the original mean of 10.69.
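The repeated filtering described above can be wrapped in a loop. The sketch below is one possible way to do it; the function name remove_outliers and the parameter k (the number of standard deviations) are assumptions made for this example. It stops as soon as a pass removes nothing.

```python
import statistics as st

sample = [10, 8, 10, 8, 2, 7, 9, 3, 34, 9, 5, 9, 25]

def remove_outliers(sample, k=2):
    """Hypothetical helper: repeat the mean +/- k*stdev filter until it is stable."""
    current = list(sample)
    while len(current) > 2:  # stdev needs at least two values
        mean = st.mean(current)
        stdev = st.stdev(current)
        low, high = mean - k * stdev, mean + k * stdev
        kept = list(filter(lambda x: low <= x <= high, current))
        if len(kept) == len(current):
            break  # nothing was removed in this pass
        current = kept
    return current

clean = remove_outliers(sample)
print(clean)                     # [10, 8, 10, 8, 2, 7, 9, 3, 9, 5, 9]
print(round(st.mean(clean), 3))  # 7.273
```

Repeating the filter until nothing changes can remove a lot of data on heavily skewed samples, so whether to iterate at all is a judgment call for the data at hand.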
To understand how filter() handles None, let’s look at the following example:
list_ = [0, 1, 'Hello', '', None, [], [1,2,3], 0.1, 0.0, False]
print(list(filter(None, list_)))
If None is passed as the function in filter(), it filters out all logically false elements (i.e., elements that are falsy by themselves). In this case, the result will be:
[1, 'Hello', [1, 2, 3], 0.1]
Here, elements like 0, 0.0, [], None, '', and False are filtered out because they have a logical value of False.
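One consequence worth keeping in mind is that filter(None, ...) also drops legitimate zeros. The sketch below, with an invented readings list, contrasts it with a predicate that removes only None values.

```python
readings = [0, 3, None, 0.0, 7, None]

# Passing None keeps only truthy elements, so 0 and 0.0 are dropped as well.
truthy_only = list(filter(None, readings))

# Passing bool gives the same result as passing None.
same_as_bool = list(filter(bool, readings))

# An explicit identity check removes None but keeps the zeros.
not_none = list(filter(lambda x: x is not None, readings))

print(truthy_only)   # [3, 7]
print(same_as_bool)  # [3, 7]
print(not_none)      # [0, 3, 0.0, 7]
```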
The function can also work with more complex data structures. For example, we may have a list of dictionaries and want to filter it based on the key-value pairs in each dictionary.
Let’s take a list of books in a bookstore:
books = [
{"Title": "Angels and Demons", "Author": "Dan Brown", "Price": 9},
{"Title": "Harry Potter and the Philosopher's Stone", "Author": "J.K. Rowling", "Price": 7},
{"Title": "Anna Karenina", "Author": "Leo Tolstoy", "Price": 5},
{"Title": "Dead Souls", "Author": "Nikolai Gogol", "Price": 4}
]
We will filter books by price. We’ll write a function that retrieves all books costing more than 5:
def cost(book):
    return book["Price"] > 5
Here, the function simply checks each book’s price and returns True if it meets the condition. To display the book titles, we iterate through the filtered object:
filtered_object = filter(cost, books)
for row in filtered_object:
    # Each row is already a dictionary, so we can index it directly
    print(row["Title"])
Result:
Angels and Demons
Harry Potter and the Philosopher's Stone
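The price limit in cost() is hard-coded. If different limits are needed, one option is a small factory function that returns a predicate for a given limit. The name costs_more_than is hypothetical and used only for this sketch; the book data is copied from the example above.

```python
books = [
    {"Title": "Angels and Demons", "Author": "Dan Brown", "Price": 9},
    {"Title": "Harry Potter and the Philosopher's Stone", "Author": "J.K. Rowling", "Price": 7},
    {"Title": "Anna Karenina", "Author": "Leo Tolstoy", "Price": 5},
    {"Title": "Dead Souls", "Author": "Nikolai Gogol", "Price": 4},
]

def costs_more_than(limit):
    """Hypothetical factory: build a predicate for the given price limit."""
    def predicate(book):
        return book["Price"] > limit
    return predicate

titles = [book["Title"] for book in filter(costs_more_than(4), books)]
print(titles)
# ['Angels and Demons', "Harry Potter and the Philosopher's Stone", 'Anna Karenina']
```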
Suppose we have the following sample:
sample = [10.1, 8.3, 10.4, 8.8, float("nan"), 7.2, float("nan")]
If we try to compute something like the mean or standard deviation on this sample, we will get nan (not a number). NaN values can appear for various reasons, so one option is to remove them from the data.
We use the isnan() function from the math module, which checks if a value is NaN:
import math
import statistics as st

sample = [10.1, 8.3, 10.4, 8.8, float("nan"), 7.2, float("nan")]

def is_not_nan(x):
    # True for ordinary numbers, False for NaN
    return not math.isnan(x)

Now, when we call:
st.mean(filter(is_not_nan, sample))
We get a result of 8.96.
Alternatively, we can simplify this by using the filterfalse() function, which retains elements where the condition is False:
from itertools import filterfalse
st.mean(filterfalse(math.isnan, sample))
The result is the same: 8.96.
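It may be tempting to filter NaN values with an ordinary comparison instead of math.isnan(). The short sketch below shows why that does not work: NaN compares unequal to everything, including itself. The variable names are chosen for illustration.

```python
import math

nan = float("nan")
sample = [10.1, nan, 7.2]

# NaN != NaN is True, so this comparison never removes anything.
naive = list(filter(lambda x: x != nan, sample))

# math.isnan() identifies NaN values reliably.
correct = list(filter(lambda x: not math.isnan(x), sample))

print(len(naive))  # 3
print(correct)     # [10.1, 7.2]
```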
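As we’ve seen, Python’s filter() function can be used in various ways. We covered some of the main applications: filtering numbers and strings, finding common elements, removing outliers and NaN values, and working with lists of dictionaries. As you continue to work with it, you’ll likely discover many other uses for this function.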