Writing unit tests for the async message queue server

Marton Trencseni - Sat 15 June 2024 - Python

Introduction

In previous posts, I wrote a simple async message queue server in Python, C++ and Javascript.

I plan to write final versions of the C++ and Javascript codebases, and make them feature- and wire-compatible with the final Python version. An easy way to make sure they really are wire-compatible is to have a library of unit tests that talk to the server over the network. The code is on GitHub.

Minor changes to the server

To make the Python server testable, I made some minor changes and added the ability to pass cache_size as a command-line parameter. This allows starting the server with a low cache_size, so fewer messages are needed to test the cache semantics:

if __name__ == "__main__":
    if len(sys.argv) not in [1, 2, 3]:
        logging.info(f"Usage: python3 aiomemq.py <port> <cache_size>")
        logging.info(f"  <port>       - optional, default {DEFAULT_PORT}")
        logging.info(f"  <cache_size> - optional, default {DEFAULT_CACHE_SIZE}")
        sys.exit(1)
    port = int(sys.argv[1]) if len(sys.argv) >= 2 else DEFAULT_PORT
    cache_size = int(sys.argv[2]) if len(sys.argv) == 3 else DEFAULT_CACHE_SIZE
    caches = defaultdict(lambda: deque(maxlen=cache_size))
    asyncio.run(run_server(host="localhost", port=port))
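The cache semantics follow directly from deque(maxlen=...): once a topic's deque is full, appending a new item silently drops the oldest one. A quick standalone illustration, where the topic names and plain-string messages are placeholders (the real server presumably stores full message dicts):

```python
from collections import defaultdict, deque

cache_size = 3  # a deliberately low value, like passing a small <cache_size> on the command line
# Same construction as in the server: one bounded deque per topic, created on first access.
caches = defaultdict(lambda: deque(maxlen=cache_size))

# Publish 5 messages to one topic; the bounded deque keeps only the newest 3.
for i in range(5):
    caches["test-topic"].append(f"message {i}")  # placeholder payloads

print(list(caches["test-topic"]))   # ['message 2', 'message 3', 'message 4']
# A topic that was never written to starts out empty.
print(len(caches["unused-topic"]))  # 0
```

With cache_size=3, a cache-size test only has to publish 4 messages to observe an eviction.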

Helper functions

To keep the unit tests modular, I wrote a small set of helper functions to send and receive messages over the network. Getting these functions "right" was harder than I initially thought, and I kept modifying them as I added more tests. I ran into two main complications:

  1. When testing subscribe messages, the server responds with a success message followed by 0..N cached messages. So although most messages we send produce exactly one success (or failure) response, sometimes more messages arrive.
  2. Sometimes the test is "don't get anything from the server", for example when subscribing to a channel that should have no cached messages. For cases like this I added a timeout, so the test can wait for a short time, such as 0.1s (long for a computer), to make sure no message arrives.
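The timeout trick for the second case can be shown in isolation. Below, expect_silence() is a hypothetical helper (not from the post's test file), and socket.socketpair() stands in for a real client/server connection:

```python
import socket

def expect_silence(sock, timeout=0.1):
    """Hypothetical helper: return True if no data arrives on `sock` within `timeout` seconds.

    recv() returning b"" (the peer closed the connection) also counts as silence.
    """
    sock.settimeout(timeout)
    try:
        return sock.recv(64 * 1024) == b""
    except socket.timeout:  # an alias of TimeoutError since Python 3.10
        return True
    finally:
        sock.settimeout(None)  # restore blocking mode for later reads

# Demo without a server: socketpair() returns two connected sockets.
a, b = socket.socketpair()
silent_before = expect_silence(b)   # nothing has been sent yet
a.sendall(b'{"success": true}\r\n')
silent_after = expect_silence(b)    # a message was waiting in the buffer
a.close()
b.close()
print(silent_before, silent_after)  # True False
```

The short timeout keeps the suite fast, at the cost of making "no message" a timing-based claim: a message that arrives after the timeout would go unnoticed.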

In the end, I settled on these helpers to receive messages:

def receive(sock):
    response = sock.recv(64*1024).decode('utf-8').strip()
    return json.loads(response)

def receive_many(sock, timeout=None, allow_trailing_bytes=False):
    sock.settimeout(timeout)
    response = ''
    try:
        response = sock.recv(64*1024).decode('utf-8').strip()
    except TimeoutError:
        pass
    sock.settimeout(None)
    messages = response.split("\r\n")
    response = []
    if allow_trailing_bytes:
        for msg in messages:
            try:
                response.append(json.loads(msg))
            except json.JSONDecodeError:
                return response
    else:
        response = [json.loads(msg) for msg in messages if msg]
    return response
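Because receive_many() mixes socket I/O with parsing, it can help to exercise the parsing on its own. parse_lines() below is a hypothetical, socket-free copy of the same split-on-"\r\n" logic, fed with hand-written buffers:

```python
import json

def parse_lines(response, allow_trailing_bytes=False):
    """Hypothetical helper: the parsing steps of receive_many(), minus the socket.

    Splits a decoded buffer on CRLF; with allow_trailing_bytes=True, parsing
    stops at the first fragment that is not valid JSON.
    """
    messages = response.strip().split("\r\n")
    if allow_trailing_bytes:
        parsed = []
        for msg in messages:
            try:
                parsed.append(json.loads(msg))
            except json.JSONDecodeError:
                break
        return parsed
    return [json.loads(msg) for msg in messages if msg]

# One recv() can return several complete messages at once...
buffer = '{"success": true}\r\n{"command": "send", "topic": "t", "msg": "hi", "index": 0, "delivery": "all"}\r\n'
print(len(parse_lines(buffer)))  # 2
# ...or end in the middle of a message, which only the lenient mode tolerates.
truncated = '{"success": true}\r\n{"command": "se'
print(parse_lines(truncated, allow_trailing_bytes=True))  # [{'success': True}]
```

In strict mode, the truncated buffer raises json.JSONDecodeError instead.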

def receive_until(sock, expected_count, timeout=0.05, allow_trailing_bytes=False):
    end_time = time.time() + timeout
    responses = []
    while time.time() < end_time:
        try:
            responses.extend(receive_many(sock, timeout, allow_trailing_bytes))
            if len(responses) >= expected_count:
                break
        except socket.timeout:
            break
    return responses
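One limitation shared by these helpers is that each recv() result is parsed on its own, so a message split across two reads (for example, because it arrived in pieces or exceeded the 64 KiB read size) cannot be reassembled. For small messages on localhost this is rare, which is presumably why the simple approach works in practice. A common remedy is to keep the unterminated tail in a buffer. A minimal sketch, assuming every server message ends with "\r\n"; LineReader is a hypothetical name, not part of the post's code:

```python
import json
import socket

class LineReader:
    """Hypothetical buffered reader that keeps partial messages between recv() calls."""

    def __init__(self, sock):
        self.sock = sock
        self.buffer = b""

    def read_messages(self, timeout=0.05):
        """Return all complete CRLF-terminated JSON messages available within `timeout`."""
        self.sock.settimeout(timeout)
        try:
            self.buffer += self.sock.recv(64 * 1024)
        except socket.timeout:
            pass
        finally:
            self.sock.settimeout(None)
        # Split on bytes (not decoded text) so a multi-byte UTF-8 character is never cut in half.
        lines = self.buffer.split(b"\r\n")
        self.buffer = lines.pop()  # incomplete tail, or b"" if the data ended on CRLF
        return [json.loads(line) for line in lines if line]

a, b = socket.socketpair()
reader = LineReader(b)
a.sendall(b'{"success": true}\r\n{"comm')  # one full message plus half of the next
first = reader.read_messages()
a.sendall(b'and": "send"}\r\n')             # the rest of the second message
second = reader.read_messages()
a.close()
b.close()
print(first, second)  # [{'success': True}] [{'command': 'send'}]
```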

...and these to send messages and read the responses:

def send(sock, message):
    if isinstance(message, str):
        sock.sendall((message + "\r\n").encode('utf-8'))
    elif isinstance(message, bytes):
        sock.sendall(message + b"\r\n")
    else:
        sock.sendall((json.dumps(message) + "\r\n").encode('utf-8'))

def send_and_receive(sock, message):
    send(sock, message)
    time.sleep(0.01)
    return receive(sock)

def send_and_receive_many(sock, message, allow_trailing_bytes=False):
    send(sock, message)
    time.sleep(0.01)
    return receive_many(sock, timeout=None, allow_trailing_bytes=allow_trailing_bytes)

def send_and_receive_until(sock, message, expected_count, timeout=0.05, allow_trailing_bytes=False):
    send(sock, message)
    time.sleep(0.01)
    return receive_until(sock, expected_count, timeout, allow_trailing_bytes)
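The three branches of send() determine what actually goes on the wire, which matters for the random-bytes tests: only the bytes branch can produce input that is not valid UTF-8 (the dict branch goes through json.dumps(), which emits ASCII by default). The snippet below runs a copy of send() against socket.socketpair() instead of the server and collects the raw bytes:

```python
import json
import socket

def send(sock, message):
    # Copy of the helper above: every message is terminated with CRLF.
    if isinstance(message, str):
        sock.sendall((message + "\r\n").encode('utf-8'))
    elif isinstance(message, bytes):
        sock.sendall(message + b"\r\n")
    else:
        sock.sendall((json.dumps(message) + "\r\n").encode('utf-8'))

a, b = socket.socketpair()
send(a, {'command': 'subscribe', 'topic': 'news'})  # dict: serialized with json.dumps
send(a, 'not json')                                 # str: sent verbatim, UTF-8 encoded
send(a, b'\xff\xfe')                                # bytes: sent as-is, not valid UTF-8
a.close()

wire = b""
while chunk := b.recv(1024):  # read until the writer has closed its end
    wire += chunk
b.close()
print(wire)
# b'{"command": "subscribe", "topic": "news"}\r\nnot json\r\n\xff\xfe\r\n'
```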

Server fixture

First, I created a pytest fixture to launch the server. A fixture is essentially a shared resource that tests declare as a dependency; with scope="module", the server is started once and shared by all tests in the module:

@pytest.fixture(scope="module")
def server():
    # Start the server as a separate process
    server_process = subprocess.Popen(['python3', 'aiomemq.py', str(SERVER_PORT), str(CACHE_SIZE)])
    time.sleep(1)  # Give the server some time to start
    yield
    server_process.terminate()
    server_process.wait()
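The fixed time.sleep(1) works, but it is either longer than necessary or, on a slow machine, too short. A common alternative is to poll until the port accepts connections. A minimal sketch, where wait_for_port() is a hypothetical helper and a throwaway listening socket plays the role of the server:

```python
import socket
import time

def wait_for_port(host, port, timeout=5.0):
    """Hypothetical helper: poll until a TCP connect to (host, port) succeeds, or give up."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=0.1):
                return True
        except OSError:  # refused or timed out: the server is not accepting yet
            if time.monotonic() >= deadline:
                return False
            time.sleep(0.05)

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(("127.0.0.1", 0))  # port 0: let the OS pick a free port
listener.listen()
port = listener.getsockname()[1]
ready = wait_for_port("127.0.0.1", port)                   # something is listening
listener.close()
not_ready = wait_for_port("127.0.0.1", port, timeout=0.3)  # refused until the deadline
print(ready, not_ready)  # True False
```

In the fixture, time.sleep(1) could then become assert wait_for_port(SERVER_HOST, SERVER_PORT). The probe itself is one extra connect/disconnect, which the server should handle like any other client.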

Unit tests: send random bytes

Any network server must be able to receive random strings and bytes (instead of valid messages) without crashing:

def _test_random_bytes(length):
    random_bytes = generate_random_bytes(length)
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client:
        response = send_and_receive_many(client, random_bytes, allow_trailing_bytes=True)
        # the random bytes could generate multiple responses
        for r in response:
            assert r == {'success': False, 'reason': 'Could not decode input as UTF-8'}


def test_random_bytes_100(server):
    return _test_random_bytes(100)

def test_random_bytes_1k(server):
    return _test_random_bytes(1024)

def test_random_bytes_10k(server):
    return _test_random_bytes(10*1024)

...
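The helpers generate_random_bytes() and generate_random_topic() are used throughout but not shown here. A plausible sketch, assuming random bytes come from os.urandom() and topics are random alphanumeric strings; the actual implementations in the repository may differ:

```python
import os
import random
import string

def generate_random_bytes(length):
    # Hypothetical implementation: uniformly random bytes, so most inputs are not valid UTF-8.
    return os.urandom(length)

def generate_random_topic(length=16):
    # Hypothetical implementation: a random alphanumeric name, unlikely to collide between tests.
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choices(alphabet, k=length))
```

Uniformly random bytes can also contain the "\r\n" delimiter, which is presumably why one send in _test_random_bytes() may produce several responses. Using a fresh random topic per test also keeps tests independent even though they share one module-scoped server.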

Unit test: connect many clients

Similarly, a network server must handle many client connections without crashing. Since our server is async, this should not be a problem:

def _test_many_connections(server, num_clients, num_messages):
    clients = []
    topics = []
    expected_counts = defaultdict(int)
    # Step 1: All clients connect and create a topic for themselves
    for i in range(num_clients):
        client = socket.create_connection((SERVER_HOST, SERVER_PORT))
        topic = generate_random_topic()
        topics.append(topic)
        clients.append((client, topic))
        # Subscribe to its own topic
        response = send_and_receive(client, {'command': 'subscribe', 'topic': topic})
        assert response == {'success': True}
    # Step 2: All clients select another random client and send it K messages
    for sender_id in range(num_clients):
        sender, sender_topic = clients[sender_id]
        for _ in range(num_messages):
            recipient_id = random.choice([i for i in range(num_clients) if i != sender_id])
            recipient_topic = topics[recipient_id]
            message = {'command': 'send', 'topic': recipient_topic, 'msg': f'test message from {sender_id} to {recipient_id}', 'delivery': 'all'}
            # Just send the message, don't read responses, since sent messages could also be arriving
            send(sender, message)
            expected_counts[sender_id] += 1    # the {'success': True}
            expected_counts[recipient_id] += 1 # the actual message
    time.sleep(1)
    # Step 3: Clients ensure they only receive messages addressed to them
    total_received_messages = 0
    for recipient_id in range(num_clients):
        client, topic = clients[recipient_id]
        # received_messages = receive_many(client)
        received_messages = receive_until(client, expected_count=expected_counts[recipient_id], timeout=0.1, allow_trailing_bytes=False)
        assert len(received_messages) == expected_counts[recipient_id]
        assert {'success': True} in received_messages
        # Check that all received messages were addressed to this client
        for msg in received_messages:
            if 'command' in msg:
                assert msg['command'] == 'send'
                assert msg['topic'] == topic
                total_received_messages += 1
    # Close all client connections
    for client, _ in clients:
        client.close()
    # Step 4: Check that in total, all N*K messages were received
    assert total_received_messages == num_clients * num_messages

def test_many_connections_10_1(server):
    _test_many_connections(server, num_clients=10, num_messages=1)

def test_many_connections_10_10(server):
    _test_many_connections(server, num_clients=10, num_messages=10)

def test_many_connections_10_100(server):
    _test_many_connections(server, num_clients=10, num_messages=100)

...

def test_many_connections_10k_10(server):
    _test_many_connections(server, num_clients=10*1000, num_messages=10)

The server can handle 10k client connections!
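At the 10k-client scale the test harness itself becomes a cost: the list comprehension in step 2 builds a list of 9,999 ids for each of the 100,000 messages. A sketch of an equivalent constant-time draw, where random_other() is a hypothetical helper and not from the post:

```python
import random

def random_other(num_clients, sender_id):
    """Hypothetical helper: a uniformly random id in range(num_clients), excluding sender_id.

    Draw from num_clients - 1 slots and shift draws at or above the sender up by one,
    so every other id is hit by exactly one slot.
    """
    recipient_id = random.randrange(num_clients - 1)
    if recipient_id >= sender_id:
        recipient_id += 1
    return recipient_id
```

In step 2, recipient_id = random_other(num_clients, sender_id) would replace the random.choice(...) line with the same distribution over recipients.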

Unit test: message validation

Are messages validated correctly by the server?

def test_subscribe_validation(server):
    topic = generate_random_topic()
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client1:
        # missing key
        response = send_and_receive(client1, {'command': 'subscribe'})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # extra key
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'extra_key': 'extra_value'})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 123, 'topic': topic})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': 123})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'last_seen': "123"})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'cache': 123})
        assert response == {'success': False, 'reason': 'Malformed json message'}

def test_unsubscribe_validation(server):
    ...

def test_send_validation(server):
    ...
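On the server side, these tests pin down a strict schema: exact key sets and exact types. A minimal sketch of a check that would satisfy test_subscribe_validation(), assuming last_seen is an integer and cache a boolean (the post only shows that a string last_seen and an integer cache are rejected); is_valid_subscribe() is hypothetical, not the server's actual code:

```python
REQUIRED_SUBSCRIBE_KEYS = {'command', 'topic'}
ALLOWED_SUBSCRIBE_KEYS = {'command', 'topic', 'last_seen', 'cache'}

def is_valid_subscribe(msg):
    """Hypothetical validator mirroring what test_subscribe_validation() expects."""
    if not isinstance(msg, dict):
        return False
    keys = set(msg)
    if not REQUIRED_SUBSCRIBE_KEYS <= keys or not keys <= ALLOWED_SUBSCRIBE_KEYS:
        return False  # missing key or extra key
    if msg['command'] != 'subscribe' or not isinstance(msg['topic'], str):
        return False  # bad command value/type or bad topic type
    # bool is a subclass of int in Python, so check the exact type for last_seen.
    if 'last_seen' in msg and type(msg['last_seen']) is not int:
        return False
    if 'cache' in msg and not isinstance(msg['cache'], bool):  # assumed boolean flag
        return False
    return True
```

The exact-type check matters because isinstance(True, int) is True, so a naive check would accept "last_seen": true.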

Unit test: test quoting in JSON

Are messages with weird quoting processed correctly?

def test_quoting(server):
    strings_with_quotes = ["''''''", '""""""', "'\"'\"'\""]
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client1:
        for s in strings_with_quotes:
            # Subscribe to the topic with quotes
            response = send_and_receive(client1, {'command': 'subscribe', 'topic': s})
            assert response == {'success': True}
            # Send a message to the topic with quotes
            message = {'command': 'send', 'topic': s, 'msg': s, 'delivery': 'all'}
            response = send_and_receive_until(client1, message, 2)
            assert {'success': True} in response
            assert {
                'command': 'send',
                'topic': s,
                'msg': s,
                'index': 0,
                'delivery': 'all'
            } in response
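These quoting cases are interesting mostly because of how json.dumps() (used by the send() helper) serializes them: double quotes must be escaped as \" on the wire, while single quotes need no escaping in JSON. A standalone check of the client-side half of the round trip:

```python
import json

strings_with_quotes = ["''''''", '""""""', "'\"'\"'\""]
for s in strings_with_quotes:
    message = {'command': 'send', 'topic': s, 'msg': s, 'delivery': 'all'}
    wire = json.dumps(message)
    print(wire)  # double quotes appear escaped, single quotes appear as-is
    # Decoding restores the original strings exactly.
    assert json.loads(wire) == message
```

A server that builds or parses messages by hand instead of with a JSON library is exactly the kind of code these tests would catch.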

Summary of unit tests

In summary, I wrote ~40 unit tests in 500+ lines of code:

  1. Random Bytes Tests: Validates the server's response to random byte inputs of varying lengths, ensuring that invalid inputs are correctly rejected with an appropriate error message.
  2. Random String Tests: Checks the server's handling of random string inputs of different lengths, verifying that malformed JSON messages are correctly identified and rejected.
  3. Message Validation Tests: Ensures that the server correctly handles various cases of malformed commands, including missing keys, extra keys, and incorrect data types.
  4. Non-Existing Commands: Validates that the server correctly identifies and rejects commands that are not recognized, returning an appropriate error message.
  5. Quoting Tests: Verifies that the server correctly processes subscribe and send commands with topics containing single quotes, double quotes, and mixed quotes.
  6. Topic Length Tests: Checks the server's ability to handle topics of varying lengths, from very short (1 character) to very long (1024 characters).
  7. Concurrent Clients Tests: Tests the server's handling of multiple concurrent clients, ensuring that messages are delivered according to the specified delivery mode (all or one).
  8. Many Connections Tests: Stress tests the server by simulating a large number of clients (10, 100, 1000, 10k) each sending messages to random clients, ensuring that all messages are correctly received and processed.
  9. Cache Behavior: Verifies the server's handling of the cache parameter in subscribe commands, ensuring that messages are delivered or withheld based on the cache setting.
  10. Cache Size Tests: Ensures that the server correctly handles cache size limits by verifying that only the most recent messages are delivered when the cache size is exceeded.
  11. Delivery Semantics: Validates the server's delivery modes (all and one) to ensure messages are delivered according to the specified delivery mode.
  12. Last Seen Behavior: Tests the last_seen parameter in subscribe commands, ensuring that only messages with an index higher than last_seen are delivered to the subscriber.
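Items 10 and 12 combine naturally: the cache bounds how far back a subscriber can see, and last_seen filters within it. A minimal sketch of those semantics, assuming cached messages carry the 'index' field seen in the quoting test and that omitting last_seen means "send everything cached"; messages_after() is a hypothetical name, not the server's code:

```python
from collections import deque

def messages_after(cache, last_seen=None):
    """Hypothetical filter: cached messages with index > last_seen (all of them if None)."""
    if last_seen is None:
        return list(cache)
    return [m for m in cache if m['index'] > last_seen]

cache = deque(maxlen=3)  # a small cache size, as in the cache-size tests
for i in range(5):
    cache.append({'command': 'send', 'topic': 't', 'msg': f'm{i}', 'index': i, 'delivery': 'all'})

print([m['index'] for m in messages_after(cache)])               # [2, 3, 4]
print([m['index'] for m in messages_after(cache, last_seen=3)])  # [4]
```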

Conclusion

The point of writing tests is to specify the behaviour of the application and to find bugs. In this case, I found only one server bug, in the message parsing code, as seen in this commit. Interestingly, ChatGPT was very effective for writing these unit tests. Although it regularly made mistakes that I had to fix, especially around calling the right helper function (e.g. should it be receive_many() or receive_until()?), using it as a faux pair-programming partner was still useful.