Merge lists in a dataframe column if they share a common value

Question:

What I need:

I have a dataframe where the elements of a column are lists. No list contains duplicate elements. For example, a dataframe like the following:

import pandas as pd

>>> d = {'col1': [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4]]}
>>> df = pd.DataFrame(data=d)

           col1
0  [1, 2, 4, 8]
1  [15, 16, 17]
2       [18, 3]
3       [2, 19]
4       [10, 4]

I would like to obtain a dataframe where, if at least one number in the list at row i also appears in the list at row j, the two lists are merged (without duplicates). A value can also be shared by more than two lists; in that case, I want all lists that share at least one value to be merged.

                   col1
0  [1, 2, 4, 8, 19, 10]
1          [15, 16, 17]
2               [18, 3]

Neither the order of the rows in the output dataframe nor the order of the values inside a list is important.


What I tried:

I have found this answer, which shows how to tell whether at least one item of a list is contained in another list, e.g.

>>> not set([1, 2, 4, 8]).isdisjoint([2, 19])
True

Returns True, since 2 is contained in both lists.

I have also found this useful answer, which shows how to compare each row of a dataframe with every other row. The answer applies a custom function to each row of the dataframe using a lambda.

df.apply(lambda row: func(row['col1']), axis=1)

However, I'm not sure how to put these two things together, i.e. how to write the func function. I also don't know whether this approach is even feasible, since the result will probably have fewer rows than the original dataframe.

Thanks!

Asked By: ClaudiaR

||

Answers:

This is not straightforward. Merging lists has many pitfalls.

One solid approach is to use a specialized library, for example networkx, and treat this as a graph problem. You can generate edges between successive items of each list and find the connected components.

Here is your graph:

[image: networkx graph list merging]

You can thus:

  • generate successive edges with add_edges_from
  • find the connected_components
  • craft a dictionary and map the first item of each list
  • groupby and merge the lists (you could use the connected components directly but I'm giving a pandas solution in case you have more columns to handle)

import networkx as nx

G = nx.Graph()
for l in df['col1']:
    G.add_edges_from(zip(l, l[1:]))

groups = {k:v for v,l in enumerate(nx.connected_components(G)) for k in l}
# {1: 0, 2: 0, 4: 0, 8: 0, 10: 0, 19: 0, 16: 1, 17: 1, 15: 1, 18: 2, 3: 2}

out = (df.groupby(df['col1'].str[0].map(groups), as_index=False)
         .agg(lambda x: sorted(set().union(*x)))
       )

output:

                   col1
0  [1, 2, 4, 8, 10, 19]
1          [15, 16, 17]
2               [3, 18]
Answered By: mozway

You can use networkx and graphs for that:

import networkx as nx

G = nx.Graph([edge for nodes in df['col1'] for edge in zip(nodes, nodes[1:])])
result = pd.Series(nx.connected_components(G))

This basically treats every number as a node and connects consecutive numbers of each list, so all numbers that appear in the same list end up connected. Finally, you find the connected components.

Output:

0  {1, 2, 4, 8, 10, 19}
1          {16, 17, 15}
2               {18, 3}
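
For readers who cannot install networkx, the same connected-components idea can be sketched in plain Python with a small union-find (disjoint-set) structure. This is a swapped-in technique, not what nx.connected_components does internally; the function name merge_overlapping is hypothetical, and the sketch assumes the values are hashable and that empty lists can be ignored.

```python
def merge_overlapping(lists):
    """Merge lists that share at least one value (transitively)."""
    parent = {}  # maps each value to its parent in the union-find forest

    def find(x):
        # Walk up to the root of x's group, shortening the path on the way.
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for values in lists:
        if not values:
            continue  # assumption: empty lists are ignored
        root = find(values[0])
        for value in values[1:]:
            # Attach the group of every other value to the first value's group.
            parent[find(value)] = root

    # Collect the values of each group under its root.
    merged = {}
    for value in parent:
        merged.setdefault(find(value), set()).add(value)
    return [sorted(group) for group in merged.values()]


data = [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4], [9]]
print(merge_overlapping(data))
# [[1, 2, 4, 8, 10, 19], [15, 16, 17], [3, 18], [9]]
```

Because every value is registered by find, single-valued lists such as [9] come out as their own group. The result can be wrapped back into pandas with pd.Series(merge_overlapping(df['col1'])).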
Answered By: Riccardo Bucco

This seems more like a Python problem than a pandas one, so here's one attempt that checks each list against every list after it and merges (and removes) any that intersect:

vals = d["col1"]

# while there is at least one list left to process...
i = 0
while i < len(vals):
    current = set(vals[i])

    # scan the lists that come after it...
    j = i + 1
    while j < len(vals):
        other = vals[j]
        if current.intersection(other):
            # shared value: absorb the other list, delete it, and rescan,
            # since the grown set may now overlap a list skipped earlier
            # (e.g. [[1, 2], [3, 4], [2, 3]])
            current.update(other)
            del vals[j]
            j = i + 1
        else:
            # no intersection, so move on to the next list
            j += 1

    # store the merged set back, and move on
    vals[i] = current
    i += 1

At the end, vals is:

In [108]: vals
Out[108]: [{1, 2, 4, 8, 10, 19}, {15, 16, 17}, {3, 18}]

In [109]: pd.Series(map(list, vals))
Out[109]:
0    [1, 2, 19, 4, 8, 10]
1            [16, 17, 15]
2                 [18, 3]
dtype: object

If you don't want d["col1"] to be modified, chain .copy() when assigning, i.e. vals = d["col1"].copy().
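
Since loop-based merging is easy to get subtly wrong, a small sanity check can help when trying any of the approaches on this page. The following is only a sketch: the helper name is_valid_merge is hypothetical, and it checks that the groups are pairwise disjoint, that every input list fits inside one group, and that no value was lost or invented. It does not detect over-merging (unrelated lists put into the same group).

```python
from itertools import combinations


def is_valid_merge(lists, groups):
    """Return True if `groups` is a consistent merge of `lists`."""
    groups = [set(g) for g in groups]

    # 1. No value may appear in two different groups.
    if any(a & b for a, b in combinations(groups, 2)):
        return False

    # 2. Every input list must fit entirely inside one group.
    if not all(any(set(lst) <= g for g in groups) for lst in lists):
        return False

    # 3. The groups must contain exactly the input values.
    return set().union(*groups) == set().union(*map(set, lists))


lists = [[1, 2], [3, 4], [2, 3], [7]]
print(is_valid_merge(lists, [[1, 2, 3, 4], [7]]))       # True
print(is_valid_merge(lists, [[1, 2, 3], [3, 4], [7]]))  # False: 3 is in two groups
```

The second call shows the typical failure of a single forward pass: [3, 4] is skipped before [2, 3] links it to the first group.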

Answered By: Mustafa Aydın

To add to mozway's answer: it wasn't clear from the question, but I also had rows with single-valued lists. These values are not added to the graph when calling add_edges_from(zip(l, l[1:])), since l[1:] is empty and zip therefore yields no edges. I solved it by adding a single node to the graph whenever l[1:] is empty. I'm leaving the solution here in case anyone needs it.

import networkx as nx
import pandas as pd

d = {'col1': [[1, 2, 4, 8], [15, 16, 17], [18, 3], [2, 19], [10, 4], [9]]}
df = pd.DataFrame(data=d)

G = nx.Graph()
for l in df['col1']:
    if len(l[1:]) == 0:
        # single-valued list: no edges to add, so add the lone node explicitly
        G.add_node(l[0])
    else:
        G.add_edges_from(zip(l, l[1:]))

groups = {k: v for v, l in enumerate(nx.connected_components(G)) for k in l}

out = (df.groupby(df['col1'].str[0].map(groups), as_index=False)
         .agg(lambda x: sorted(set().union(*x)))
       )

Result:

                   col1
0  [1, 2, 4, 8, 10, 19]
1          [15, 16, 17]
2               [3, 18]
3                   [9]
Answered By: ClaudiaR