Remove all elements in each list of a nested list, based on first nested list

Question:

I have the following list:

a = [[0, 1, 0, 1, 1, 1], [23, 22, 12, 45, 32, 33], [232, 332, 222, 342, 321, 232]]

I want to remove the 0s in a[0] and the corresponding values in a[1] and a[2], so the resulting list should be:

d = [[1, 1, 1, 1], [22, 45, 32, 33], [332, 342, 321, 232]]
Asked By: Sher


Answers:

You probably want a function something like this (currently untested):

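# list[...] annotations require Python 3.9+
def filter_list(ls: list[list[int]]) -> list[list[int]]:
    # The first sublist controls which positions are kept
    keep_list = ls[0]

    def filter_out(sub: list[int]) -> list[int]:
        # Keep a value only if the control list is non-zero at the same index
        return [value for idx, value in enumerate(sub) if keep_list[idx] != 0]

    return [filter_out(sub) for sub in ls]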

It’s a bit unusual that the first list in your list of lists is the controlling one, but essentially you want to walk through each list (including the first) and check whether the controlling list (keep_list) has a 0 at that position, dropping the element if so. I put this in an inner function because it’s easier to reason about; you could write a nested list comprehension or nested for loops and get the same result.
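For comparison, here is a minimal sketch of the nested list comprehension and nested for-loop versions mentioned above, assuming the same a from the question; the names d_comp and d_loops are just illustrative.

```python
a = [[0, 1, 0, 1, 1, 1], [23, 22, 12, 45, 32, 33], [232, 332, 222, 342, 321, 232]]
keep_list = a[0]  # the controlling (first) sublist

# Nested list comprehension: outer loop over sublists, inner loop over positions
d_comp = [[value for idx, value in enumerate(sub) if keep_list[idx] != 0] for sub in a]

# The same logic written as nested for loops
d_loops = []
for sub in a:
    filtered = []
    for idx, value in enumerate(sub):
        if keep_list[idx] != 0:
            filtered.append(value)
    d_loops.append(filtered)

print(d_comp)             # [[1, 1, 1, 1], [22, 45, 32, 33], [332, 342, 321, 232]]
print(d_comp == d_loops)  # True
```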

Answered By: Nathaniel Ford

itertools.compress is built for this task. Pair it with a list comprehension to operate on each sublist, wrap each call in list() to force it to eagerly produce each new sublist (compress itself returns a lazy iterator), and you get what you want with fairly simple code:

from itertools import compress

a = [[0, 1, 0, 1, 1, 1], [23, 22, 12, 45, 32, 33], [232, 332, 222, 342, 321, 232]]

new_a = [list(compress(sublst, a[0])) for sublst in a]

print(new_a)

Which outputs:

[[1, 1, 1, 1], [22, 45, 32, 33], [332, 342, 321, 232]]


Each call uses a[0] as the selectors that decide which elements of the data (the current sublist) to keep: when the selector value from a[0] is falsy (0 here), the corresponding element is dropped; when it’s truthy (1 here), it’s retained.
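To illustrate the selector behaviour described above, the sketch below uses the a[1] values from the question as standalone data. It shows that any truthy/falsy values work as selectors, and that compress stops as soon as either input is exhausted, so a selector list shorter than the data silently drops the trailing elements.

```python
from itertools import compress

data = [23, 22, 12, 45, 32, 33]

# 0/1, booleans, or any other truthy/falsy values all work as selectors
print(list(compress(data, [0, 1, 0, 1, 1, 1])))                     # [22, 45, 32, 33]
print(list(compress(data, [False, True, False, True, True, True])))  # [22, 45, 32, 33]
print(list(compress(data, ["", "x", None, 5, [1], -1])))             # [22, 45, 32, 33]

# compress stops when either iterable runs out, so the data
# elements past the end of a short selector list are dropped
print(list(compress(data, [0, 1, 0])))                               # [22]
```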

Answered By: ShadowRanger

I like the itertools.compress answer. However, nested lists of integers in Python are almost always better stored in numpy arrays, which offer rich ways to select rows/columns for tasks such as this:

>>> import numpy as np
>>> a = [[0, 1, 0, 1, 1, 1], [23, 22, 12, 45, 32, 33], [232, 332, 222, 342, 321, 232]]
>>> a = np.array(a)
>>> a[:, a[0]!=0]
array([[  1,   1,   1,   1],
       [ 22,  45,  32,  33],
       [332, 342, 321, 232]])

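NumPy arrays store integers compactly, and basic slicing returns views instead of copies, so NumPy is often more memory efficient too. Note, however, that boolean-mask indexing such as a[:, a[0]!=0] always returns a new array (a copy), not a view.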
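One way to check whether an indexing operation produced a view or a copy is numpy.shares_memory. A minimal sketch, assuming the same array as above; the slice 1:4 is an arbitrary choice for comparison.

```python
import numpy as np

a = np.array([[0, 1, 0, 1, 1, 1],
              [23, 22, 12, 45, 32, 33],
              [232, 332, 222, 342, 321, 232]])

mask = a[0] != 0        # boolean mask over the columns
filtered = a[:, mask]   # boolean (advanced) indexing: always a new array
sliced = a[:, 1:4]      # basic slicing: a view onto a's buffer

print(np.shares_memory(filtered, a))  # False, so it is a copy
print(np.shares_memory(sliced, a))    # True, so it is a view
```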

Going back to Python lists is easy:

>>> a[:, a[0]!=0].tolist()
[[1, 1, 1, 1], [22, 45, 32, 33], [332, 342, 321, 232]]
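NumPy also has its own numpy.compress, which mirrors itertools.compress but can select along an axis in one call. A sketch, assuming the same data; with axis=1 it keeps the same columns as the boolean mask above.

```python
import numpy as np

a = np.array([[0, 1, 0, 1, 1, 1],
              [23, 22, 12, 45, 32, 33],
              [232, 332, 222, 342, 321, 232]])

# Keep the columns (axis=1) where the first row is non-zero
selected = np.compress(a[0] != 0, a, axis=1)

print(selected.tolist())  # [[1, 1, 1, 1], [22, 45, 32, 33], [332, 342, 321, 232]]
```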
Answered By: wim