Writing unit tests for the async message queue server
Marton Trencseni - Sat 15 June 2024 - Python
Introduction
In previous posts, I wrote a simple async message queue server in Python, C++ and Javascript:
- Writing a simple Python async message queue server - Part I
- Writing a simple Python async message queue server - Part II
- Writing a simple Python async message queue server - Part III
- Writing a simple C++ async message queue server - Part I
- Writing a simple C++ async message queue server - Part II
- Writing a simple Javascript async message queue server - Part I
- Writing a simple Javascript async message queue server - Part II
I plan to write final versions of the C++ and Javascript codebases, and make them feature- and wire-compatible with the final Python version. One easy way to make sure they are in fact wire-compatible is to have a library of unit tests. The code is on GitHub.
Minor changes to the server
To make the Python server testable, I made some minor changes and added the ability to pass the cache_size as a command-line parameter. This allows starting the server with a low cache_size, so fewer messages are needed to test the cache semantics:
if __name__ == "__main__":
    if len(sys.argv) not in [1, 2, 3]:
        logging.info("Usage: python3 aiomemq.py <port> <cache_size>")
        logging.info(f"  <port> - optional, default {DEFAULT_PORT}")
        logging.info(f"  <cache_size> - optional, default {DEFAULT_CACHE_SIZE}")
        sys.exit(1)
    # sys.argv[0] is the script name, so the port is only present when len(sys.argv) >= 2
    port = int(sys.argv[1]) if len(sys.argv) >= 2 else DEFAULT_PORT
    cache_size = int(sys.argv[2]) if len(sys.argv) == 3 else DEFAULT_CACHE_SIZE
    caches = defaultdict(lambda: deque(maxlen=cache_size))
    asyncio.run(run_server(host="localhost", port=port))
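To see why a low cache_size makes the cache semantics cheap to test, here is a minimal standalone sketch of the same defaultdict-of-deque construction. The value 3 and the topic name "news" are made up for the illustration and are not taken from the server or the tests.

```python
from collections import defaultdict, deque

CACHE_SIZE = 3  # hypothetical small value, chosen only for this illustration

# Same construction as in the server: one bounded deque per topic.
caches = defaultdict(lambda: deque(maxlen=CACHE_SIZE))

# Append five messages to one topic; a full deque drops its oldest entry.
for index in range(5):
    caches["news"].append({"index": index, "msg": f"message {index}"})

remaining = [m["index"] for m in caches["news"]]
print(remaining)  # [2, 3, 4]
```

With a cache of 3, five messages are enough to observe eviction; with a large default cache size the test would need to send correspondingly more.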
Helper functions
To make the unit tests modular, I wrote a small set of helper functions to send and receive messages over the network. Getting these functions "right" was harder than I initially thought, and I kept modifying them as I added more tests. I ran into two main complications:
- When testing subscribe messages, the server responds with a success message, followed by 0..N cached messages. So although for most messages we send we just want to receive one success (or failure) message, sometimes there may be more messages incoming.
- Sometimes the test is "don't get anything from the server", for example when subscribing to a channel where there should be no cached messages. For cases like this I added a timeout, so the test can wait for a short time like 0.1s (short for a human, but long for a computer) to make sure it doesn't get a message.
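The "don't get anything" case can be sketched in isolation with a socket timeout. The function name expect_silence is hypothetical and not part of the test suite, and socket.socketpair() stands in for a real client/server connection so the snippet runs without the server.

```python
import socket

def expect_silence(sock, timeout=0.1):
    # Hypothetical helper: True if nothing arrives on sock within `timeout` seconds.
    sock.settimeout(timeout)
    try:
        sock.recv(64 * 1024)
    except socket.timeout:
        return True   # the timeout expired without any data
    finally:
        sock.settimeout(None)  # restore blocking mode for later calls
    return False      # recv() returned: data arrived or the peer closed the connection

client, server_side = socket.socketpair()
quiet_before = expect_silence(client)           # nothing has been sent yet
server_side.sendall(b'{"success": true}\r\n')
quiet_after = expect_silence(client)            # a message is waiting
client.close()
server_side.close()
```

The price of this kind of check is that every "expect nothing" assertion costs the full timeout, so the value is a trade-off between test speed and confidence.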
I ended up with these helpers to receive messages:
import json
import socket
import time

def receive(sock):
    # expects exactly one complete message to arrive in a single recv() call
    response = sock.recv(64*1024).decode('utf-8').strip()
    return json.loads(response)

def receive_many(sock, timeout=None, allow_trailing_bytes=False):
    sock.settimeout(timeout)
    response = ''
    try:
        response = sock.recv(64*1024).decode('utf-8').strip()
    except socket.timeout:
        pass # nothing arrived within the timeout
    sock.settimeout(None)
    messages = response.split("\r\n")
    response = []
    if allow_trailing_bytes:
        # parse messages until the first chunk that is not valid JSON
        for msg in messages:
            try:
                response.append(json.loads(msg))
            except json.JSONDecodeError:
                return response
    else:
        response = [json.loads(msg) for msg in messages if msg]
    return response

def receive_until(sock, expected_count, timeout=0.05, allow_trailing_bytes=False):
    end_time = time.time() + timeout
    responses = []
    while time.time() < end_time:
        try:
            responses.extend(receive_many(sock, timeout, allow_trailing_bytes))
            if len(responses) >= expected_count:
                break
        except socket.timeout:
            break
    return responses
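These helpers treat each recv() as a batch of messages separated by "\r\n", which works well on localhost with small messages. As a sketch of a stricter alternative (not what the tests use), a small buffer can hold back an incomplete tail until the rest arrives. The class name MessageBuffer and its feed() method are hypothetical.

```python
import json

DELIMITER = b"\r\n"

class MessageBuffer:
    # Hypothetical helper: collects bytes across recv() calls and only
    # parses messages that are fully terminated by the delimiter.
    def __init__(self):
        self.pending = b""

    def feed(self, data):
        self.pending += data
        # Everything before the last delimiter is complete; the rest waits.
        *complete, self.pending = self.pending.split(DELIMITER)
        return [json.loads(chunk.decode("utf-8")) for chunk in complete if chunk]

buf = MessageBuffer()
# The second message is cut in half, as can happen with TCP reads.
first = buf.feed(b'{"success": true}\r\n{"command": "se')
second = buf.feed(b'nd", "topic": "t", "msg": "hi"}\r\n')
print(first)   # [{'success': True}]
print(second)  # [{'command': 'send', 'topic': 't', 'msg': 'hi'}]
```

Because the bytes are concatenated before splitting, a delimiter that is itself split across two reads is handled as well.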
... and these to send messages, optionally followed by a receive:
def send(sock, message):
    if isinstance(message, str):
        sock.sendall((message + "\r\n").encode('utf-8'))
    elif isinstance(message, bytes):
        sock.sendall(message + b"\r\n")
    else:
        sock.sendall((json.dumps(message) + "\r\n").encode('utf-8'))

def send_and_receive(sock, message):
    send(sock, message)
    time.sleep(0.01)
    return receive(sock)

def send_and_receive_many(sock, message, allow_trailing_bytes=False):
    send(sock, message)
    time.sleep(0.01)
    return receive_many(sock, timeout=None, allow_trailing_bytes=allow_trailing_bytes)

def send_and_receive_until(sock, message, expected_count, timeout=0.05, allow_trailing_bytes=False):
    send(sock, message)
    time.sleep(0.01)
    return receive_until(sock, expected_count, timeout, allow_trailing_bytes)
Server fixture
First I created a server fixture to launch the server. A fixture is essentially a shared resource that the tests depend on:
@pytest.fixture(scope="module")
def server():
    # Start the server as a separate process
    server_process = subprocess.Popen(['python3', 'aiomemq.py', str(SERVER_PORT), str(CACHE_SIZE)])
    time.sleep(1) # Give the server some time to start
    yield
    server_process.terminate()
    server_process.wait()
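The fixed one-second sleep is simple and works here. If startup time ever becomes a problem, one possible alternative is to poll the port until it accepts connections. The sketch below is an assumption on my part and not part of the test suite: wait_for_port is a hypothetical name, and a throwaway listening socket stands in for the server so the snippet runs on its own.

```python
import socket
import time

def wait_for_port(host, port, timeout=5.0, interval=0.05):
    # Hypothetical helper: True once a TCP connection succeeds, False on timeout.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            time.sleep(interval)  # not accepting connections yet, retry shortly
    return False

# Stand-in for the server: a listening socket on a port chosen by the OS.
listener = socket.socket()
listener.bind(("127.0.0.1", 0))
listener.listen()
free_port = listener.getsockname()[1]

ready = wait_for_port("127.0.0.1", free_port, timeout=1.0)
listener.close()
```

Note that each successful probe opens and immediately closes a connection, so the server would see one short-lived client before the tests start.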
Unit tests: send random bytes
Any network server needs to be able to receive random strings and bytes (instead of valid messages) and not crash:
def _test_random_bytes(length):
    random_bytes = generate_random_bytes(length)
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client:
        response = send_and_receive_many(client, random_bytes, allow_trailing_bytes=True)
        # the random bytes could generate multiple responses, each of which must be a rejection;
        # the usual reason is 'Could not decode input as UTF-8'
        for r in response:
            assert r['success'] is False

def test_random_bytes_100(server):
    return _test_random_bytes(100)

def test_random_bytes_1k(server):
    return _test_random_bytes(1024)

def test_random_bytes_10k(server):
    return _test_random_bytes(10*1024)
...
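The generate_random_bytes() helper is not shown in this post. A minimal stand-in, assuming it simply returns `length` random bytes, could look like the sketch below. The is_valid_utf8() function is hypothetical and only illustrates why random input usually triggers the UTF-8 decoding error.

```python
import os

def generate_random_bytes(length):
    # Assumed behaviour: `length` random bytes from the OS random source.
    return os.urandom(length)

def is_valid_utf8(data):
    # Hypothetical check mirroring a server-side decode step.
    try:
        data.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

sample = generate_random_bytes(1024)
print(len(sample), is_valid_utf8(sample))
```

Long random inputs are very unlikely to be valid UTF-8, but short chunks between accidental line breaks can be, which is one reason a single input may produce several different responses.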
Unit test: connect many clients
On a similar note, a network server needs to be able to handle a lot of client connections and not crash. Since our server is async, this should be no problem:
def _test_many_connections(server, num_clients, num_messages):
    clients = []
    topics = []
    expected_counts = defaultdict(int)
    # Step 1: All clients connect and create a topic for themselves
    for i in range(num_clients):
        client = socket.create_connection((SERVER_HOST, SERVER_PORT))
        topic = generate_random_topic()
        topics.append(topic)
        clients.append((client, topic))
        # Subscribe to its own topic
        response = send_and_receive(client, {'command': 'subscribe', 'topic': topic})
        assert response == {'success': True}
    # Step 2: All clients select another random client and send it K messages
    for sender_id in range(num_clients):
        sender, sender_topic = clients[sender_id]
        for _ in range(num_messages):
            recipient_id = random.choice([i for i in range(num_clients) if i != sender_id])
            recipient_topic = topics[recipient_id]
            message = {'command': 'send', 'topic': recipient_topic, 'msg': f'test message from {sender_id} to {recipient_id}', 'delivery': 'all'}
            # Just send the message, don't read responses, since sent messages could also be arriving
            send(sender, message)
            expected_counts[sender_id] += 1 # the {'success': True}
            expected_counts[recipient_id] += 1 # the actual message
    time.sleep(1)
    # Step 3: Clients ensure they only receive messages addressed to them
    total_received_messages = 0
    for recipient_id in range(num_clients):
        client, topic = clients[recipient_id]
        # received_messages = receive_many(client)
        received_messages = receive_until(client, expected_count=expected_counts[recipient_id], timeout=0.1, allow_trailing_bytes=False)
        assert len(received_messages) == expected_counts[recipient_id]
        assert {'success': True} in received_messages
        # Check that all received messages were addressed to this client
        for msg in received_messages:
            if 'command' in msg:
                assert msg['command'] == 'send'
                assert msg['topic'] == topic
                total_received_messages += 1
    # Close all client connections
    for client, _ in clients:
        client.close()
    # Step 4: Check that in total, all N*K messages were received
    assert total_received_messages == num_clients * num_messages

def test_many_connections_10_1(server):
    _test_many_connections(server, num_clients=10, num_messages=1)

def test_many_connections_10_10(server):
    _test_many_connections(server, num_clients=10, num_messages=10)

def test_many_connections_10_100(server):
    _test_many_connections(server, num_clients=10, num_messages=100)
...
def test_many_connections_10k_10(server):
    _test_many_connections(server, num_clients=10*1000, num_messages=10)
The server can handle 10k client connections!
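The bookkeeping in Step 2 is the subtle part of this test: every send adds one expected message for the sender (the success reply) and one for the recipient (the delivered message). The sketch below replays only that counting logic without any networking. The function name plan_traffic and the fixed seed are hypothetical.

```python
import random
from collections import defaultdict

def plan_traffic(num_clients, num_messages, seed=0):
    # Hypothetical helper: same counting as Step 2, with no sockets involved.
    rng = random.Random(seed)
    expected_counts = defaultdict(int)
    for sender_id in range(num_clients):
        for _ in range(num_messages):
            recipient_id = rng.choice([i for i in range(num_clients) if i != sender_id])
            expected_counts[sender_id] += 1     # the {'success': True} reply
            expected_counts[recipient_id] += 1  # the delivered message
    return expected_counts

counts = plan_traffic(num_clients=10, num_messages=10)
total = sum(counts.values())
print(total)  # 200: N*K success replies plus N*K delivered messages
```

This is why Step 4 compares only the messages that carry a 'command' key against N*K: the other half of the traffic consists of success replies.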
Unit test: message validation
Are messages validated correctly by the server?
def test_subscribe_validation(server):
    topic = generate_random_topic()
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client1:
        # missing key
        response = send_and_receive(client1, {'command': 'subscribe'})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # extra key
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'extra_key': 'extra_value'})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 123, 'topic': topic})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': 123})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'last_seen': "123"})
        assert response == {'success': False, 'reason': 'Malformed json message'}
        # bad type
        response = send_and_receive(client1, {'command': 'subscribe', 'topic': topic, 'cache': 123})
        assert response == {'success': False, 'reason': 'Malformed json message'}

def test_unsubscribe_validation(server):
    ...

def test_send_validation(server):
    ...
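Read together, these cases describe a small schema for subscribe messages. The sketch below is one way to express that schema as code. It is not the server's actual validation logic, and the types for last_seen (int) and cache (bool) are assumptions inferred from the "bad type" cases above.

```python
def is_valid_subscribe(msg):
    # Hypothetical validator reconstructed from the test cases, not from the server.
    required = {"command": str, "topic": str}
    optional = {"last_seen": int, "cache": bool}  # assumed types
    if not isinstance(msg, dict):
        return False
    # All required keys must be present, and no unknown keys are allowed.
    if not set(required) <= set(msg) <= set(required) | set(optional):
        return False
    for key, value in msg.items():
        expected_type = required.get(key) or optional.get(key)
        if type(value) is not expected_type:  # exact match, so True is not an int
            return False
    return msg["command"] == "subscribe"

cases = [
    {"command": "subscribe"},                                   # missing key
    {"command": "subscribe", "topic": "t", "extra_key": "x"},   # extra key
    {"command": 123, "topic": "t"},                             # bad type
    {"command": "subscribe", "topic": 123},                     # bad type
    {"command": "subscribe", "topic": "t", "last_seen": "123"}, # bad type
    {"command": "subscribe", "topic": "t", "cache": 123},       # bad type
]
results = [is_valid_subscribe(c) for c in cases]
print(results)  # [False, False, False, False, False, False]
```

Having the schema in one place like this makes it easier to check that the C++ and Javascript servers reject exactly the same inputs.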
Unit test: test quoting in JSON
Are messages with weird quoting processed correctly?
def test_quoting(server):
    strings_with_quotes = ["''''''", '""""""', "'\"'\"'\""]
    with socket.create_connection((SERVER_HOST, SERVER_PORT)) as client1:
        for s in strings_with_quotes:
            # Subscribe to the topic with quotes
            response = send_and_receive(client1, {'command': 'subscribe', 'topic': s})
            assert response == {'success': True}
            # Send a message to the topic with quotes
            message = {'command': 'send', 'topic': s, 'msg': s, 'delivery': 'all'}
            response = send_and_receive_until(client1, message, 2)
            assert {'success': True} in response
            assert {
                'command': 'send',
                'topic': s,
                'msg': s,
                'index': 0,
                'delivery': 'all'
            } in response
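On the client side, json.dumps() takes care of escaping, so what this test really checks is that the server decodes and re-encodes these strings without mangling them. The snippet below shows what the three strings look like on the wire and that they survive a round trip through Python's json module; it does not involve the server.

```python
import json

strings_with_quotes = ["''''''", '""""""', "'\"'\"'\""]

round_trips = []
for s in strings_with_quotes:
    # Encode the same way the send() helper does, then decode again.
    wire = json.dumps({'command': 'subscribe', 'topic': s}) + "\r\n"
    decoded = json.loads(wire.strip())
    round_trips.append(decoded['topic'] == s)
    print(wire.strip())

print(round_trips)  # [True, True, True]
```

Single quotes pass through unchanged, while each double quote is escaped as \" inside the JSON string.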
Summary of unit tests
In summary, I wrote ~40 unit tests in 500+ lines of code:
- Random Bytes Tests: Validates the server's response to random byte inputs of varying lengths, ensuring that invalid inputs are correctly rejected with an appropriate error message.
- Random String Tests: Checks the server's handling of random string inputs of different lengths, verifying that malformed JSON messages are correctly identified and rejected.
- Message Validation Tests: Ensures that the server correctly handles various cases of malformed commands, including missing keys, extra keys, and incorrect data types.
- Non-Existing Commands: Validates that the server correctly identifies and rejects commands that are not recognized, returning an appropriate error message.
- Quoting Tests: Verifies that the server correctly processes subscribe and send commands with topics containing single quotes, double quotes, and mixed quotes.
- Topic Length Tests: Checks the server's ability to handle topics of varying lengths, from very short (1 character) to very long (1024 characters).
- Concurrent Clients Tests: Tests the server's handling of multiple concurrent clients, ensuring that messages are delivered according to the specified delivery mode (all or one).
- Many Connections Tests: Stress tests the server by simulating a large number of clients (10, 100, 1000, 10k) each sending messages to random clients, ensuring that all messages are correctly received and processed.
- Cache Behavior: Verifies the server's handling of the cache parameter in subscribe commands, ensuring that messages are delivered or withheld based on the cache setting.
- Cache Size Tests: Ensures that the server correctly handles cache size limits by verifying that only the most recent messages are delivered when the cache size is exceeded.
- Delivery Semantics: Validates the server's delivery modes (all and one) to ensure messages are delivered according to the specified delivery mode.
- Last Seen Behavior: Tests the last_seen parameter in subscribe commands, ensuring that only messages with an index higher than last_seen are delivered to the subscriber.
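The cache size and last_seen rules from the list above can be illustrated together in a few lines. This is a sketch of the semantics as described, not the server's code: the function messages_to_replay is hypothetical, and replaying the whole cache when last_seen is absent is an assumption.

```python
from collections import deque

def messages_to_replay(cache, last_seen=None):
    # Hypothetical helper: choose which cached messages a new subscriber receives.
    if last_seen is None:
        return list(cache)  # assumption: without last_seen, replay the whole cache
    return [m for m in cache if m["index"] > last_seen]

# A cache of size 3 that has seen five messages keeps only indexes 2, 3 and 4.
cache = deque(maxlen=3)
for index in range(5):
    cache.append({"index": index, "msg": f"message {index}"})

replay_all = [m["index"] for m in messages_to_replay(cache)]
replay_new = [m["index"] for m in messages_to_replay(cache, last_seen=3)]
print(replay_all)  # [2, 3, 4]
print(replay_new)  # [4]
```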
Conclusion
The point of writing tests is to specify the behaviour of the application and find bugs. In this case, I only found one server bug, in the message parsing code, as seen in this commit. It's interesting to note that using ChatGPT was very effective for writing these unit tests. Although it regularly made mistakes which I had to fix, especially around calling the right helper function (e.g. should it be receive_many() or receive_until()), using it as a faux pair programming partner was still useful.