Find how many connected groups of nodes there are in a given adjacency matrix

Question:

I have a list of lists, where each list is a node and contains the edges to other nodes, e.g.:
[[1, 1, 0], [1, 1, 0], [0, 0, 1]]

Each node's list has a 1 at the node's own position and at the position of every node it has an edge to, and a 0 wherever no edge exists.

This means that node 0 ([1, 1, 0]) is connected to node 1, and node 2 ([0, 0, 1]) is not connected to any other node. Therefore this list of lists can be thought of as an adjacency matrix:

1 1 0 <- node 0
1 1 0 <- node 1
0 0 1 <- node 2

Moreover, connectivity is transitive: if node 1 is connected to node 2 and node 2 is connected to node 3, then nodes 1 and 3 are also connected.

Taking all this into account, I want to know how many connected groups a given matrix contains. What algorithm should I use? Recursive DFS? Can someone provide hints or pseudocode for how to approach this problem?

Asked By: Boa


Answers:

There are several approaches. You can solve this problem with DFS/BFS or with a disjoint-set (union-find) data structure. Here are some useful links:

https://www.geeksforgeeks.org/connected-components-in-an-undirected-graph/
https://www.geeksforgeeks.org/find-the-number-of-islands-set-2-using-disjoint-set/
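A minimal sketch of both suggested approaches in Python, assuming the matrix is symmetric (undirected edges) and indexable as adj[i][j] with truthy values for edges. The function names are illustrative and not taken from the linked articles. Both approaches follow edges along paths, so neither relies on the matrix already being transitively closed. The DFS uses an explicit stack rather than recursion, which sidesteps Python's default recursion limit on large graphs.

```python
def count_components_dfs(adj):
    """Count connected components with an iterative depth-first search."""
    n = len(adj)
    visited = [False] * n
    count = 0
    for start in range(n):
        if visited[start]:
            continue
        count += 1  # start is the first node seen from a new component
        visited[start] = True
        stack = [start]
        while stack:
            node = stack.pop()
            # Push every unvisited neighbour of the current node
            for neighbor in range(n):
                if adj[node][neighbor] and not visited[neighbor]:
                    visited[neighbor] = True
                    stack.append(neighbor)
    return count


def count_components_union_find(adj):
    """Count connected components with a disjoint-set (union-find) structure."""
    n = len(adj)
    parent = list(range(n))  # each node starts as its own root

    def find(x):
        # Path halving: point x at its grandparent while walking up to the root
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    count = n  # every successful union merges two components into one
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle suffices for a symmetric matrix
            if adj[i][j]:
                root_i, root_j = find(i), find(j)
                if root_i != root_j:
                    parent[root_i] = root_j
                    count -= 1
    return count


# A path 0 - 1 - 2 that is NOT transitively closed, plus an isolated node 3
path = [[1, 1, 0, 0],
        [1, 1, 1, 0],
        [0, 1, 1, 0],
        [0, 0, 0, 1]]
print(count_components_dfs(path))         # 2
print(count_components_union_find(path))  # 2
```

The DFS inspects each matrix entry at most once, so it runs in O(n^2) time for n nodes.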

Answered By: P.Gupta

If the input matrix is guaranteed to describe transitive connectivity, it has a special form: the row of any node has a 1 for every member of that node's group and a 0 everywhere else. This allows for an algorithm that probes only a subset of the matrix elements. Here is an example implementation in Python:

def count_connected_groups(adj):
    """Count groups in a transitively closed adjacency matrix."""
    n = len(adj)
    nodes_to_check = set(range(n))  # nodes not yet assigned to a group
    count = 0
    while nodes_to_check:
        count += 1  # any unassigned node starts a new group
        node = nodes_to_check.pop()
        adjacent = adj[node]
        # Thanks to transitivity, this node's row lists its whole group
        other_group_members = set()
        for i in nodes_to_check:
            if adjacent[i]:
                other_group_members.add(i)
        nodes_to_check -= other_group_members
    return count


# your example:
adj_0 = [[1, 1, 0], [1, 1, 0], [0, 0, 1]]
# same with tuples and booleans:
adj_1 = ((True, True, False), (True, True, False), (False, False, True))
# another connectivity matrix:
adj_2 = ((1, 1, 1, 0, 0),
         (1, 1, 1, 0, 0),
         (1, 1, 1, 0, 0),
         (0, 0, 0, 1, 1),
         (0, 0, 0, 1, 1))
# and yet another:
adj_3 = ((1, 0, 1, 0, 0),
         (0, 1, 0, 1, 0),
         (1, 0, 1, 0, 0),
         (0, 1, 0, 1, 0),
         (0, 0, 0, 0, 1))
for a in adj_0, adj_1, adj_2, adj_3:
    print(a)
    print(count_connected_groups(a))


# [[1, 1, 0], [1, 1, 0], [0, 0, 1]]
# 2
# ((True, True, False), (True, True, False), (False, False, True))
# 2
# ((1, 1, 1, 0, 0), (1, 1, 1, 0, 0), (1, 1, 1, 0, 0), (0, 0, 0, 1, 1), (0, 0, 0, 1, 1))
# 2
# ((1, 0, 1, 0, 0), (0, 1, 0, 1, 0), (1, 0, 1, 0, 0), (0, 1, 0, 1, 0), (0, 0, 0, 0, 1))
# 3
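The function above relies on the precondition that the matrix is transitively closed. On a matrix that lists only direct edges it can overcount: for the path 0 - 1 - 2 below, if node 0 is picked first, node 2 is not in node 0's row, so it is left over and counted as a separate group. A minimal sketch of a check for that precondition, using a hypothetical helper name is_transitively_closed and assuming a symmetric 0/1 (or boolean) matrix as in the question:

```python
def is_transitively_closed(adj):
    """Return True if, whenever i-j and j-k are connected, i-k is too."""
    n = len(adj)
    for i in range(n):
        for j in range(n):
            if not adj[i][j]:
                continue
            # Every neighbour k of j must also be a neighbour of i
            for k in range(n):
                if adj[j][k] and not adj[i][k]:
                    return False
    return True


path = [[1, 1, 0],
        [1, 1, 1],
        [0, 1, 1]]
print(is_transitively_closed([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))  # True
print(is_transitively_closed(path))                                # False
```

The check itself takes O(n^3) time, so it is more useful as a debugging aid than as something to run before every count.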

An optimized version of the same algorithm (less readable, but faster and more easily translatable into other languages) is the following:

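def count_connected_groups(adj):
    n = len(adj)
    # nodes_to_check[:n] always holds the nodes not yet assigned to a group
    nodes_to_check = list(range(n))  # [0, 1, ..., n-1]
    count = 0
    while n:
        count += 1
        # Take the last unchecked node as the representative of a new group
        n -= 1
        node = nodes_to_check[n]
        adjacent = adj[node]
        i = 0
        while i < n:
            other_node = nodes_to_check[i]
            if adjacent[other_node]:
                # Same group: overwrite it with the last unchecked node and shrink
                n -= 1
                nodes_to_check[i] = nodes_to_check[n]
            else:
                i += 1
    return count

Answered By: Walter Tross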

The same algorithm in Java, with each matrix row given as a string of '0' and '1' characters:

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ConnectedGroups {

    public static void main(String[] args) {
        // Each string is one row of the adjacency matrix ('1' = edge or self, '0' = no edge)
        List<String> adj0 = Arrays.asList("110", "110", "001");
        List<String> adj1 = Arrays.asList("10000", "01000", "00100", "00010", "00001");
        List<String> adj2 = Arrays.asList("11100", "11100", "11100", "00011", "00011");
        List<String> adj3 = Arrays.asList("10100", "01010", "10100", "01010", "00001");

        for (List<String> related : Arrays.asList(adj0, adj1, adj2, adj3)) {
            System.out.println(related);
            System.out.println(countConnectedGroups(related));  // 2, 5, 2, 3
        }
    }

    // Assumes the matrix is transitively closed, like the Python version above
    private static int countConnectedGroups(List<String> adj) {
        int count = 0;
        int n = adj.size();
        Deque<Integer> nodesToCheck = new ArrayDeque<>();
        for (int i = 0; i < n; i++) {
            nodesToCheck.push(i);
        }

        while (!nodesToCheck.isEmpty()) {
            count++;  // any node still unchecked starts a new group
            int node = nodesToCheck.pop();
            String adjacent = adj.get(node);
            // Collect the remaining members of this node's group
            Set<Integer> otherGroupMembers = new HashSet<>();
            for (int i : nodesToCheck) {
                if (adjacent.charAt(i) == '1') {
                    otherGroupMembers.add(i);
                }
            }
            nodesToCheck.removeAll(otherGroupMembers);
        }
        return count;
    }
}

Answered By: Onur Akan