RabbitMQ Python Pika – acknowledgements sent with add_callback_threadsafe are lost when stopping the program

Question:

I’m using Pika to process RabbitMQ messages in small batches, with one thread per batch.
At the end of the thread function, I acknowledge the messages by scheduling the acknowledgement on the channel's connection with add_callback_threadsafe.

In parallel, I catch SIGINT to stop the program cleanly: before stopping consumption and closing the connection, I wait with thread.join() for all threads to finish.

But once Ctrl+C sends SIGINT, the acknowledgements are never processed, even though the program waits for all threads to finish.

Is there a way to force the channel/connection to process the pending add_callback_threadsafe callbacks before closing the connection?

(Screenshot from the original post not available.)

# import packages
# connect to Rabbit MQ
import pika
# intercept stop signal
import signal
# print exception
import traceback
# threading
import functools
import threading
from queue import Queue
# logs time
import datetime
import time

# Function Message Acknowledgement
def ack_message(ch, delivery_tag):
    """Acknowledge one delivery on the channel it was received from.

    In this script it is scheduled from worker threads through
    ``add_callback_threadsafe``. AMQP requires ``ch`` to be the very pika
    channel instance that delivered the message.

    :param ch: pika channel the message arrived on
    :param delivery_tag: delivery tag of the message to acknowledge
    """
    print(f'DEBUG ack_message : begining of ack_message function')

    if not ch.is_open:
        # The channel is already closed, so the ACK cannot be sent any more;
        # only report it (adapt this branch to what the application needs).
        now = datetime.datetime.now()
        elapsed = str(datetime.timedelta(seconds=time.time() - init_time))
        print(now, elapsed, f'ERROR Channel Closed when trying to Acknowledge')
        return

    ch.basic_ack(delivery_tag)
    print(f'DEBUG ack_message : Acknowledgement delivered')

# Function Process multiple messages in separate thread
def block_process():
    """Drain the shared message queue, process the batch, then ACK it.

    Runs on a worker thread, started either directly or by a
    ``threading.Timer``. It reads the module globals ``event``,
    ``list_Boby_Tag`` and ``channel``. The ACKs are not sent from this
    thread. Each one is scheduled on the connection with
    ``add_callback_threadsafe``, the pika call intended for use from other
    threads.
    """
    from queue import Empty

    body_list = []
    tag_list = []

    print(f'DEBUG block_process : start of block_process function')

    # Cancel the pending timer, if any: everything queued is processed here.
    # (Thread.isAlive() was removed in Python 3.9; is_alive() is the API.)
    if event is not None and event.is_alive():
        event.cancel()

    # Drain the queue without blocking. Pairing qsize() with blocking get()
    # calls can hang forever when another worker drains the same queue at the
    # same time, so stop as soon as the queue reports Empty.
    while True:
        try:
            body, tag = list_Boby_Tag.get_nowait()
        except Empty:
            break
        body_list.append(body)
        tag_list.append(tag)

    # do something that takes time with the block of messages in body_list
    time.sleep(10)
    for body in body_list:
        body_str = body.decode()
        print(f'DEBUG block_process : message processed is {body_str}')

    # Schedule one ACK per tag on the connection's thread.
    for tag in tag_list:
        print(f'DEBUG preprare delivering Acknowledgement from thread')
        cb = functools.partial(ack_message, channel, tag)
        channel.connection.add_callback_threadsafe(cb)

    print(f'DEBUG block_process : end of block_process function')

# Function Process message by message and call
def process_message(ch, method, properties, body):
    """pika ``on_message_callback``: queue one delivery and schedule its batch.

    The ``(body, delivery_tag)`` pair is pushed onto ``list_Boby_Tag``.
    - When the queue reaches 500 entries, a worker thread runs
      ``block_process`` immediately.
    - Otherwise a 10-second ``threading.Timer`` is re-armed, so a partial
      batch is still processed.

    Every thread started is recorded in ``threads`` so shutdown can join it.

    :param ch: channel the message arrived on (unused; the worker uses the
        global ``channel``)
    :param method: delivery metadata; ``method.delivery_tag`` is queued
    :param properties: message properties (unused)
    :param body: raw message body
    """
    global threads
    global event

    # do nothing if this flag is on, as the program is about to close
    if PauseConsume == 1:
        return

    # Cancel the pending timer: we either process a full block now or re-arm
    # a fresh timer below. (Thread.isAlive() was removed in Python 3.9;
    # is_alive() is the supported spelling.)
    if event is not None and event.is_alive():
        event.cancel()

    # put the body and delivery tag in the queue as a tuple
    list_Boby_Tag.put((body, method.delivery_tag))

    if list_Boby_Tag.qsize() == 500:
        # Full block: forget finished threads, then process it right away in a
        # separate thread that is tracked so shutdown can wait for it.
        threads = [x for x in threads if x.is_alive()]
        t = threading.Thread(target=block_process)
        t.start()
        threads.append(t)
    elif list_Boby_Tag.qsize() > 0:
        # Partial block: process it after 10 seconds unless a new message
        # re-arms the timer first; the timer thread is tracked as well.
        event = threading.Timer(interval=10, function=block_process)
        event.start()
        threads.append(event)

# PARAMETERS
RabbitMQ_host = '192.168.1.190'
RabbitMQ_port = 5672
RabbitMQ_queue = 'test_ctrlC'
RabbitMQ_cred_un = 'xxxx'
RabbitMQ_cred_pd = 'xxxx'

# init variables for batch process
list_Boby_Tag = Queue()   # (body, delivery_tag) tuples waiting for a batch
threads = list()          # worker/timer threads joined on shutdown
event = None              # pending batch timer, if any
PauseConsume = 0          # process_message ignores deliveries while this is 1
init_time = time.time()   # start time, used for elapsed-time logging

# connect to RabbitMQ via Pika
cred = pika.credentials.PlainCredentials(RabbitMQ_cred_un,RabbitMQ_cred_pd)
connection = pika.BlockingConnection(pika.ConnectionParameters(host=RabbitMQ_host, port=RabbitMQ_port, credentials=cred))
channel = connection.channel()
channel.queue_declare(queue=RabbitMQ_queue,durable=True)
# Maximum number of unacknowledged messages the broker pushes to this
# consumer. Messages are only ACKed after their whole batch has been processed
# (block_process), so the prefetch must be at least the batch size used in
# process_message (500). With prefetch_count=1 the broker would hold back every
# further message until the first one is ACKed. Each "batch" would then hold a
# single message, released only by the 10-second timer.
channel.basic_qos(prefetch_count=500)

# define the consumer
channel.basic_consume(queue=RabbitMQ_queue,
                      auto_ack=False, # false = messages must be acknowledged explicitly (basic_ack)
                      on_message_callback=process_message)

# empty the queue and generate test data
channel.queue_purge(queue=RabbitMQ_queue)
# wait a few seconds so the purge can be checked in the RabbitMQ UI
print(f'DEBUG main : queue {RabbitMQ_queue} purged')
connection.sleep(10)
# generate 10 test messages
for msgId in range(10):
    channel.basic_publish(exchange='',
                        routing_key=RabbitMQ_queue,
                        body=f'message{msgId}',
                        properties=pika.BasicProperties(
                            delivery_mode = pika.spec.PERSISTENT_DELIVERY_MODE
                        ))
print(f'DEBUG main : test messages created in {RabbitMQ_queue}')

# Function clean stop of pika connection in case of interruption or exception
def cleanClose():
    """Stop consuming and close the connection once all workers are done.

    Steps:
    1. Set the global ``PauseConsume`` flag so ``process_message`` ignores
       new deliveries.
    2. Join every recorded worker/timer thread, so that all their ACK
       callbacks have been queued.
    3. Give the connection 3 seconds to process events.
    4. Stop consuming and close the connection.
    """
    # Without this declaration the assignment below only creates a local
    # variable, and process_message would never see the flag.
    global PauseConsume
    PauseConsume = 1
    # Wait for all threads to complete
    for thread in threads:
        thread.join()
    # connection.sleep() keeps processing connection events while waiting,
    # which is what lets the queued ACK callbacks run.
    # NOTE(review): when called from the SIGINT handler this executes inside
    # the interrupted start_consuming() call, where pika may not dispatch the
    # queued callbacks — consistent with the lost ACKs; confirm.
    connection.sleep(3)
    channel.stop_consuming()
    connection.close()

# Function handle exit signals
def exit_handler(signum, frame):
    """SIGINT handler: log the signal, shut down cleanly, then exit with status 0.

    :param signum: number of the received signal
    :param frame: current stack frame (unused)
    """
    print(
        datetime.datetime.now(),
        str(datetime.timedelta(seconds=time.time() - init_time)),
        f'Exit signal received ({signum})',
    )
    # NOTE(review): Python runs signal handlers on the main thread, i.e. here
    # inside whatever pika call start_consuming() was executing, so the
    # blocking pika calls made by cleanClose() are re-entrant — a likely cause
    # of the unprocessed ACK callbacks; confirm against pika's documentation.
    cleanClose()
    # `exit()` is a convenience added by the `site` module (absent under
    # `python -S`); raising SystemExit is the reliable equivalent.
    raise SystemExit(0)

# Install the graceful-shutdown handler for SIGINT (sent by CTRL+C or by a
# Docker stop configured to send SIGINT); SIGTSTP is deliberately not handled.
signal.signal(signal.SIGINT, exit_handler)
#signal.signal(signal.SIGTSTP, exit_handler) # send by a CTRL+Z Docker Stop

print(' [*] Waiting for messages. To exit press CTRL+C')
try:
    # Blocks here, dispatching deliveries to process_message.
    channel.start_consuming()
except Exception:
    # Log the failure with wall-clock time and elapsed run time, dump the
    # traceback, then shut down as cleanly as possible.
    print(
        datetime.datetime.now(),
        str(datetime.timedelta(seconds=time.time() - init_time)),
        f'Exception received within start_consumming',
    )
    traceback.print_exc()
    cleanClose()
    cleanClose()

Asked By: Stephane

||

Answers:

Luke found a workaround here:
https://github.com/lukebakken/pika-1402/blob/lukebakken/pika-1402/test_pika_blockthread.py

He changed the sample code as follows:

  • Simplified it a bit by removing the "batch processing via Queue" code since it wasn’t related to the current issue
  • Moved the Pika connection to its own thread
  • Instead of using a consume callback, he moved to a generator-style for loop, which makes it possible to check whether exiting has been requested. This could also be accomplished with SelectConnection and a timer.

Here is the sample code with the Queue-based batch processing added back, matching the original code:

# from https://github.com/lukebakken/pika-1402/blob/lukebakken/pika-1402/test_pika_blockthread.py
# import packages
# connect to Rabbit MQ
import pika
import pika.credentials
import pika.spec
# intercept stop signal
import signal
# print exception
# import traceback
# threading
import functools
import threading
from queue import Queue

# logs time
import datetime
import time

# PARAMETERS
RabbitMQ_host = "192.168.1.190"
RabbitMQ_port = 5672
RabbitMQ_queue = "test_ctrlC"
RabbitMQ_cred_un = "xxxx"
RabbitMQ_cred_pd = "xxxx"

nbTest = 100000  # number of test messages published at start-up
nbBatch = 1000  # queue size that triggers immediate batch processing
nbPrefetch = 10000  # broker prefetch: max unacknowledged deliveries
timerSec = 60  # seconds to wait before processing a partial batch
workSec = 5  # seconds of simulated work per batch

# Messages are only ACKed once their batch has been processed. A prefetch
# smaller than the batch size would stop deliveries before a batch can fill,
# so batches would only ever be flushed by the timer. Fail fast instead.
if nbPrefetch < nbBatch:
    raise ValueError(
        f"nbPrefetch ({nbPrefetch}) must be >= nbBatch ({nbBatch})"
    )

# init variables for batch process
init_time = time.time()
exiting = False
work_threads = list()
event = None
list_Boby_Tag = Queue()

# Function Message Acknowledgement
def ack_message(ch, delivery_tag):
    """ACK ``delivery_tag`` on ``ch`` if that channel is still open.

    AMQP only accepts an ACK on the same pika channel instance that delivered
    the message, so ``ch`` must be that channel. In this script the function
    is scheduled through ``add_callback_threadsafe`` by the worker threads.
    """
    if not ch.is_open:
        # Too late to ACK: the channel is closed. Report it; adapt this branch
        # to whatever makes sense for the application.
        now = datetime.datetime.now()
        elapsed = str(datetime.timedelta(seconds=time.time() - init_time))
        print(now, elapsed, "Channel Closed when trying to Acknowledge")
        return
    ch.basic_ack(delivery_tag)

# Function Process multiple messages in separate thread
def do_work(channel, list_Boby_Tag):
    """Drain ``list_Boby_Tag``, simulate work on the batch, then ACK it.

    Runs on a worker thread (a plain ``threading.Thread`` or a
    ``threading.Timer``). The ACKs are not sent from here: each one is
    scheduled on the connection with ``add_callback_threadsafe``.

    :param channel: pika channel the queued messages were delivered on
    :param list_Boby_Tag: queue of ``(body, delivery_tag)`` tuples
    """
    from queue import Empty

    body_list = []
    tag_list = []

    # cancel the timer if it exists, as everything queued is processed here
    if event and event.is_alive():
        event.cancel()

    # Drain without blocking. Pairing qsize() with blocking get() calls can
    # hang forever: a timer worker and a batch worker may drain the same queue
    # at the same time, and items counted by qsize() can then be taken by the
    # other thread before get() runs.
    while True:
        try:
            body, tag = list_Boby_Tag.get_nowait()
        except Empty:
            break
        body_list.append(body)
        tag_list.append(tag)

    # do something that takes time with the block of messages in body_list
    time.sleep(workSec)

    for body in body_list:
        body.decode()  # placeholder for the real per-message processing

    # Schedule one ACK per tag on the connection's thread.
    for tag in tag_list:
        cb = functools.partial(ack_message, channel, tag)
        channel.connection.add_callback_threadsafe(cb)

# Function Process message by message and call thread by block or timer
def process_message(ch, method, body):
    """Queue one delivery and make sure some worker will process it.

    - Once ``nbBatch`` entries are queued, a new ``do_work`` thread handles
      them straight away.
    - Otherwise a ``threading.Timer`` runs ``do_work`` after ``timerSec``
      seconds, unless the next message re-arms it first.

    Every thread started is tracked in ``work_threads`` for shutdown.

    :param ch: channel the message was delivered on (handed to ``do_work``)
    :param method: delivery metadata; its ``delivery_tag`` is queued
    :param body: raw message body
    """
    global work_threads, list_Boby_Tag, event

    # Any pending timer is superseded by the decision made below.
    if event and event.is_alive():
        event.cancel()

    list_Boby_Tag.put((body, method.delivery_tag))

    if list_Boby_Tag.qsize() == nbBatch:
        # Full batch: forget finished threads, then process it right away.
        work_threads = [th for th in work_threads if th.is_alive()]
        worker = threading.Thread(target=do_work, args=(ch, list_Boby_Tag))
    elif list_Boby_Tag.qsize() > 0:
        # Partial batch: process it later unless a new message re-arms it.
        worker = event = threading.Timer(
            interval=timerSec, function=do_work, args=(ch, list_Boby_Tag)
        )
    else:
        return

    worker.start()
    work_threads.append(worker)

# Function to start pika channel and stopping it
def pika_runner():
    """Own the pika connection: set up, publish test data, consume, shut down.

    Every pika call in this script happens here, on a dedicated thread; the
    SIGINT handler only sets ``exiting``. Consumption uses the generator form
    of ``channel.consume`` with ``inactivity_timeout=1``, so ``exiting`` is
    re-checked even when the queue is idle. On exit the function:
    1. cancels the consumer;
    2. joins the worker threads;
    3. gives the connection 5 seconds to run the ACK callbacks they queued;
    4. closes the channel and the connection.
    """
    # connect to RabbitMQ via Pika
    cred = pika.credentials.PlainCredentials(RabbitMQ_cred_un, RabbitMQ_cred_pd)
    connection = pika.BlockingConnection(
        pika.ConnectionParameters(
            host=RabbitMQ_host, port=RabbitMQ_port, credentials=cred
        )
    )
    channel = connection.channel()
    channel.queue_declare(queue=RabbitMQ_queue, durable=True)
    # Allow up to nbPrefetch unacknowledged deliveries: messages are only
    # ACKed after their batch is processed, so this must be >= nbBatch.
    channel.basic_qos(prefetch_count=nbPrefetch)

    # empty the queue and generate test data
    channel.queue_purge(queue=RabbitMQ_queue)
    # wait a few seconds so the purge can be checked in the RabbitMQ UI
    print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG main : queue {RabbitMQ_queue} purged, sleeping 5 seconds")
    connection.sleep(5)
    # generate test messages
    for msgId in range(nbTest):
        channel.basic_publish(
            exchange="",
            routing_key=RabbitMQ_queue,
            body=f"message-{msgId+1}",
            properties=pika.BasicProperties(
                delivery_mode=pika.spec.PERSISTENT_DELIVERY_MODE
            ),
        )
    print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG main : test messages created in {RabbitMQ_queue}")

    # Consume until an exit is requested. After one idle second the generator
    # yields (None, None, None), which is what lets `exiting` be re-checked.
    for method_frame, _properties, body in channel.consume(
        queue=RabbitMQ_queue, inactivity_timeout=1, auto_ack=False
    ):
        if not exiting:
            if method_frame is not None:
                process_message(channel, method_frame, body)
            continue

        print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG : stopping consuming")
        # Cancel the generator consumer; pika also rejects (requeues) the
        # deliveries it buffered but never yielded.
        channel.cancel()
        print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG : joining work threads")
        for thread in work_threads:
            thread.join()
        print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG : all work threads done, sleeping 5 seconds to let acks be delivered")
        # The workers' ACKs were queued with add_callback_threadsafe; they are
        # executed while this thread services the connection here.
        connection.sleep(5)
        print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),f"DEBUG : closing connections and channels")
        channel.close()
        connection.close()
        # Leave explicitly rather than resuming the consume generator on a
        # connection that has just been closed.
        break


# Function handle exit signals
def exit_handler(signum, _):
    """SIGINT handler: ask the pika thread to stop, wait for it, then exit.

    It only sets the ``exiting`` flag; the pika thread notices the flag and
    performs the shutdown itself, so no pika call is made from the handler.
    A repeated signal while the shutdown is in progress returns immediately.

    :param signum: number of the received signal (unused)
    """
    global exiting
    if exiting:
        return
    exiting = True
    print(datetime.datetime.now(),str(datetime.timedelta(seconds=time.time() - init_time)),"Exit signal received")
    pika_thread.join()
    # `exit()` is a convenience added by the `site` module (absent under
    # `python -S`); raising SystemExit is the reliable equivalent.
    raise SystemExit(0)

# Create the thread that will connect to RabbitMQ and consume. Install the
# SIGINT handler *before* starting it: a CTRL+C arriving after start() but
# before the handler existed would raise KeyboardInterrupt in the main thread,
# while the (non-daemon) pika thread kept consuming and never saw `exiting`.
pika_thread = threading.Thread(target=pika_runner)
signal.signal(signal.SIGINT, exit_handler)  # send by a CTRL+C or modified Docker Stop
pika_thread.start()
print(" [*] Waiting for messages. To exit press CTRL+C")

# Wait for the pika thread first; on a clean shutdown it joins its own
# workers. Then join whatever workers are still tracked (e.g. if pika_runner
# ended with an exception). `work_threads` is read only after pika_runner has
# finished, because process_message rebinds that global while running.
pika_thread.join()
for thread in work_threads:
    thread.join()

Answered By: Stephane