How would I add a tqdm or similar progress bar to this multiprocessing script?

Question:

I’ve been trying with tqdm but can’t quite get it to work with apply_async. The relevant parts of the code are the `start()` and `resume()` functions. Basically, I want a progress bar that advances each time a job finishes.

featurizer:

class UrlFeaturizer(object):
    def __init__(self, url):
        self.url = url
        try:
            self.response = requests.get(
                prepend_protocols(self.url), headers=headers, timeout=5
            )
        except Exception:
            self.response = None
        try:
            self.whois = whois.query(self.url).__dict__
        except Exception:
            self.whois = None
        try:
            self.soup_c = BeautifulSoup(
                self.response.content,
                features="lxml",
                from_encoding=self.response.encoding,
            )
        except Exception:
            self.soup_c = None
        # try:
        #     self.soup_t = BeautifulSoup(self.response.text, features="lxml")
        # except Exception:
        #     self.soup_t = None

    def lookup_whois(self) -> int:
        """
        Look up the age of self.url.

        :return:
            - The domain's age - If WHOIS domains for self.url is available.
            - None - If WHOIS domains for self.url is unavailable.
        """
        return int(False) if self.whois else int(True)

    def lookup_domain_age(self) -> int:
        """
        Look up the age of self.url.

        :return:
            - The domain's age - If WHOIS domains for self.url is available.
            - None - If WHOIS domains for self.url is unavailable.
        """
        if self.whois and self.whois["creation_date"]:
            return (date.today() - self.whois["creation_date"].date()).days
        return

    def verify_ssl(self) -> bool:
        """
        Verify the SSL certificate of self.url.

        :return:
            - True - If self.url's SSL certificate was verified.
            - False - If self.url's SSL certificate could not be verified.
        :raise Exception: If there is an error getting self.url's SSL
        certificate.
        """
        try:
            ssl_cert = ssl.get_server_certificate((self.url, 443), timeout=10)
            return int(True) if ssl_cert else int(False)
        except Exception:
            return

    def check_security(self) -> bool:
        """
        Check whether self.url is using a SSL certificate.

        :return:
            - True - If self.url is using a SSL certificate.
            - False - If self.url isn't using a SSL certificate.
        :raise requests.ConnectionError: If there is an error during Requests
        connecting to self.url.
        """
        try:
            requests.head(f"https://{self.url}", timeout=10)
            return int(True)
        except Exception:
            return int(False)

    def has_com_tld(self):
        """

        :return:
        """
        return int(True) if extract_tld(self.url) == "com" else int(False)

    def count_num_subdomains(self) -> int:
        """
        Count the number of subdomains in self.url.

        :return: The number of subdomains in self.url.
        """
        return self.url.count(".") - 1

    def run(self, dataset=None):
        data = {
            "url": self.url,
            "uses_whois_privacy": self.lookup_whois(),
            "domain_age": self.lookup_domain_age(),
            "has_ssl": self.verify_ssl(),
            "is_secure": self.check_security(),
            "has_com_tld": self.has_com_tld(),
            "num_subdomains": self.count_num_subdomains(),
        }
        return data.keys(), data

data_collection:

"""This module extracts features from a list of URLs."""

import csv
import multiprocessing as mp
import multiprocessing.managers
import os
import pathlib
from utils import test_internet_connection
import pandas as pd
from tqdm.auto import tqdm

from discovery.featurizer import UrlFeaturizer

# Column order for num_features.csv. Must list exactly the keys, in order, of
# the dict built by UrlFeaturizer.run(); csv.DictWriter raises ValueError for
# any row key missing from this list. Spelled out instead of calling
# UrlFeaturizer("1.1.1.1").run(...), which performed live HTTP/WHOIS/SSL
# lookups at import time - and again in every process that re-imports this
# module under the spawn start method.
keys = (
    "url",
    "uses_whois_privacy",
    "domain_age",
    "has_ssl",
    "is_secure",
    "has_com_tld",
    "num_subdomains",
)


def worker(url: str, dataset: str, q: multiprocessing.managers.AutoProxy) -> type(None):
    """
    Featurize a single URL and hand the resulting row to the listener.

    Executed inside a pool worker process. The row is sent through the Queue
    rather than written here, so only the listener touches the CSV file.

    :param url: The URL to run through UrlFeaturizer.
    :param dataset: Name of the dataset file the URL came from; forwarded
        unchanged to UrlFeaturizer.run.
    :param q: The Queue read by the listener.
    :return: None
    """
    try:
        _field_names, row = UrlFeaturizer(url).run(dataset)
        q.put(row)
    except (AttributeError, TimeoutError):
        # Best-effort: a URL failing with one of these simply produces no row.
        return None
    return None


def listener(q: multiprocessing.managers.AutoProxy) -> type(None):
    """
    Append rows received on the Queue to num_features.csv until "kill" arrives.

    Meant to run as the single writer of the file (collect_feature_data
    submits it once to the pool), so worker rows never interleave.

    :param q: The Queue fed by worker(); each item is either a feature dict
        whose keys are listed in ``keys``, or the string "kill" to stop.
    :return: None
    """
    # newline="" lets the csv module control line endings itself, as the csv
    # docs require for files passed to a writer.
    with open("num_features.csv", "a", newline="") as f:
        # One writer for the whole session instead of a new one per row.
        csv_out = csv.DictWriter(f, keys)
        while True:
            message = q.get()
            if message == "kill":
                break
            csv_out.writerow(message)
            # Flush each row so an interrupted run keeps everything written so
            # far (resume() reads this file to skip processed URLs).
            f.flush()


def start(
    pool: multiprocessing.pool.Pool,
    q: multiprocessing.managers.AutoProxy,
    jobs: list[multiprocessing.pool.ApplyResult],
) -> type(None):
    """
    Start collecting data from the first URL in every dataset.

    Submits one worker task per URL and shows a tqdm progress bar that
    advances as tasks finish, in completion order, whether they succeed or
    fail. The internet connection is checked every 100 results; if it is
    lost the pool is terminated. Finally the listener is told to stop, the
    pool is shut down and the CSV is sorted.

    :param pool: The process pool.
    :param q: The Queue.
    :param jobs: The list that will contain multiprocessing results.
    :return: None
    """
    datasets = ["benign_domains.csv", "dmca_domains.csv",]
    tasks = []
    for dataset in datasets:
        urls = pd.read_csv(dataset, header=None).iloc[:, 0].to_list()
        tasks.extend((url, dataset, q) for url in urls)

    # All tasks are counted before submission so the bar knows its total
    # before any task can finish.
    with tqdm(total=len(tasks), desc="Collecting features") as pbar:

        def advance(_):
            # Runs in the pool's result-handling thread; also registered as
            # error_callback so failed tasks still move the bar.
            pbar.update()

        for task in tasks:
            jobs.append(
                pool.apply_async(
                    worker, task, callback=advance, error_callback=advance
                )
            )
        for count, job in enumerate(jobs):
            if count % 100 == 0 and not test_internet_connection():
                os.system("say Internet connection lost.")
                pool.terminate()
                break
            job.get()
    q.put("kill")
    pool.close()
    pool.join()
    sort_csv()
    return


def resume(
    pool: multiprocessing.pool.Pool,
    q: multiprocessing.managers.AutoProxy,
    jobs: list[multiprocessing.pool.ApplyResult],
) -> type(None):
    """
    Resume collecting data, skipping URLs already present in num_features.csv.

    Used when a previous run was interrupted. Shows a tqdm progress bar over
    the remaining URLs that advances as tasks finish, in completion order,
    whether they succeed or fail. The internet connection is checked every
    100 results; if it is lost the pool is terminated. Finally the listener
    is told to stop, the pool is shut down and the CSV is sorted.

    :param pool: The process pool.
    :param q: The Queue.
    :param jobs: The list that will contain multiprocessing results.
    :return: None
    """
    # A set makes each "already processed?" test O(1) instead of scanning a
    # list. header=None means the header cell "url" is read as a value too,
    # which is harmless unless a dataset contains the literal string "url".
    processed_urls = set(
        pd.read_csv(
            "num_features.csv",
            usecols=[0],
            header=None,
        )
        .iloc[:, 0]
        .to_list()
    )
    datasets = ["benign_domains.csv", "dmca_domains.csv",]
    tasks = []
    for dataset in datasets:
        unprocessed_urls = pd.read_csv(dataset, header=None).iloc[:, 0].to_list()
        tasks.extend(
            (url, dataset, q)
            for url in unprocessed_urls
            if url not in processed_urls
        )

    # All tasks are counted before submission so the bar knows its total
    # before any task can finish.
    with tqdm(total=len(tasks), desc="Collecting features") as pbar:

        def advance(_):
            # Runs in the pool's result-handling thread; also registered as
            # error_callback so failed tasks still move the bar.
            pbar.update()

        for task in tasks:
            jobs.append(
                pool.apply_async(
                    worker, task, callback=advance, error_callback=advance
                )
            )
        for count, job in enumerate(jobs):
            if count % 100 == 0 and not test_internet_connection():
                os.system("say Internet connection lost.")
                pool.terminate()
                break
            job.get()
    q.put("kill")
    pool.close()
    pool.join()
    sort_csv()
    return


def write_header_to_csv(log_file: pathlib.Path) -> type(None):
    """
    Write the CSV header row to log_file if the file does not exist yet.

    :param log_file: The CSV file that the collected data is written to.
    :return: None
    """
    if log_file.is_file():
        return
    # Write to the same path that was just checked, so the existence test and
    # the write target can never disagree. newline="" is what the csv module
    # expects for files handed to a writer.
    with open(log_file, "a", newline="") as f:
        csv.DictWriter(f, keys).writeheader()
    return


def sort_csv(path: str = "num_features.csv", by: str = "label") -> type(None):
    """
    Group the rows of a CSV by one column, rewriting the file in place.

    A stable sort keeps the rows' original relative order within each group;
    the default quicksort does not guarantee that.

    NOTE(review): the feature dict built by UrlFeaturizer.run contains no
    "label" key, so with the default ``by`` this raises KeyError unless the
    CSV gains such a column elsewhere - confirm.

    :param path: The CSV file to sort; defaults to the collection output.
    :param by: The column to group by.
    :return: None
    :raise KeyError: If the CSV has no ``by`` column.
    """
    df = pd.read_csv(path)
    df = df.sort_values(by=[by], kind="stable")
    df.to_csv(path, index=False)
    return


def collect_feature_data() -> type(None):
    """
    Write extracted URL features to num_features.csv.

    Ensures the CSV has a header, then starts a Manager-backed Queue and a
    process pool running one listener task that owns the file. It then either
    resumes an interrupted run or starts from scratch.

    :return: None
    """
    log_file = pathlib.Path("num_features.csv")
    write_header_to_csv(log_file)
    # The context managers shut down the Manager's server process and the
    # pool even if start()/resume() raise. Pool.__exit__ calls terminate(),
    # which is harmless once start()/resume() have already joined the pool.
    with mp.Manager() as manager, mp.Pool(mp.cpu_count() + 2) as pool:
        q = manager.Queue()
        watcher = pool.apply_async(listener, (q,))
        jobs = []
        # NOTE(review): 730 bytes presumably means "header plus some rows";
        # TODO confirm the threshold.
        if log_file.is_file() and log_file.stat().st_size > 730:
            resume(pool, q, jobs)
        else:
            start(pool, q, jobs)
        # Re-raise anything the listener raised (e.g. a row with an unexpected
        # key) instead of silently losing it. Only check when it has finished:
        # after pool.terminate() its result is never set.
        if watcher.ready():
            watcher.get()
    return


if __name__ == "__main__":
    # Guard required by multiprocessing: under the spawn/forkserver start
    # methods, the pool and Manager processes re-import the main module and
    # must not start another collection run themselves.
    collect_feature_data()
Asked By: ariyasas94

||

Answers:

This is how I would modify the start function (you would do something similar for resume).

Note that I am passing a callback to apply_async so that, as soon as a submitted task completes (regardless of the order in which it was submitted), I can update the progress bar. That means I must create the progress bar before I start submitting tasks. To do that, I must first know how many tasks will eventually be submitted, so that I can initialize the bar with the correct total. That is why I first collect all of the arguments for apply_async in the variable `args`: its length is the number of tasks that will be submitted.

import tqdm

def start(
    pool: multiprocessing.pool.Pool,
    q: multiprocessing.managers.AutoProxy,
    jobs: list[multiprocessing.pool.ApplyResult],
) -> type(None):
    """
    Start collecting data from the first URL in the list.

    A tqdm bar advances from apply_async callbacks, so it tracks completion
    order rather than submission order, and counts failed tasks as well.

    :param pool: The process pool.
    :param q: The Queue.
    :param jobs: The list that will contain multiprocessing results.
    :return: None
    """
    datasets = ["benign_domains.csv", "dmca_domains.csv",]
    args = []
    for dataset in datasets:
        urls = pd.read_csv(dataset, header=None).iloc[:, 0].to_list()
        for url in urls:
            args.append((url, dataset, q))
    # total= sizes the bar directly instead of handing tqdm an iterable that
    # is never iterated.
    with tqdm.tqdm(total=len(args)) as pbar:
        for arg in args:
            job = pool.apply_async(
                worker,
                arg,
                callback=lambda result: pbar.update(1),
                # Without this, a task that raises never advances the bar.
                error_callback=lambda exc: pbar.update(1),
            )
            jobs.append(job)
        for count, job in enumerate(jobs):
            if count % 100 == 0:
                if not test_internet_connection():
                    os.system("say Internet connection lost.")
                    pool.terminate()
                    break
            job.get()
    q.put("kill")
    pool.close()
    pool.join()
    sort_csv()
    return

You might consider replacing the lambda expression used for the callback with a real function and calling test_internet_connection from inside that callback. That way the connection is tested after every 100 completions, regardless of the order in which tasks complete:

import tqdm
from threading import Event


def start(
    pool: multiprocessing.pool.Pool,
    q: multiprocessing.managers.AutoProxy,
    jobs: list[multiprocessing.pool.ApplyResult],
) -> type(None):
    """
    Start collecting data from the first URL in the list.

    Progress and the every-100-completions connection check are driven by an
    apply_async callback, so both follow completion order rather than
    submission order.

    :param pool: The process pool.
    :param q: The Queue.
    :param jobs: The list that will contain multiprocessing results.
    :return: None
    """
    datasets = ["benign_domains.csv", "dmca_domains.csv",]
    args = []
    for dataset in datasets:
        urls = pd.read_csv(dataset, header=None).iloc[:, 0].to_list()
        for url in urls:
            args.append((url, dataset, q))

    completed = 0
    connection_lost = False
    completion = Event()

    def my_callback(result):
        # Runs in the pool's result-handling thread. It is registered as both
        # callback and error_callback, so every task - successful or not - is
        # counted and completion.wait() below cannot hang on a failed task.
        nonlocal completed, connection_lost
        pbar.update()
        completed += 1
        if completed == len(args):
            completion.set()
        elif completed % 100 == 0 and not test_internet_connection():
            # Only signal here; the main thread terminates the pool, so the
            # submission loop never races a pool that is already torn down.
            connection_lost = True
            completion.set()

    with tqdm.tqdm(total=len(args)) as pbar:
        if not args:
            completion.set()  # nothing will ever call my_callback
        for arg in args:
            job = pool.apply_async(
                worker, arg, callback=my_callback, error_callback=my_callback
            )
            jobs.append(job)
        completion.wait()
        if connection_lost:
            os.system("say Internet connection lost.")
            pool.terminate()
    # The listener occupies one of the pool's workers, so it must be told to
    # stop *before* pool.join(); otherwise join() waits forever on it.
    q.put("kill")
    pool.close()
    pool.join()
    sort_csv()
    return
Answered By: Booboo