I recently had the opportunity to speed up some Websocket code that was a major bottleneck. The final solution was 60X (!) faster than the first pass, and an interesting exercise in optimizing inner loops.

When you send data via a websocket it is masked with a randomly generated four-byte key. This masking provides protection from certain exploits against proxies (details here). The algorithm is super simple; essentially each byte in the payload is XORed with the corresponding byte of the key, with the key repeating every four bytes.
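As a small worked example of that rule, here is a minimal sketch for Python 3. The helper name xor_mask and the fixed key are made up for illustration; real keys are random. It also shows that XOR is its own inverse, so unmasking is the same operation as masking.

```python
def xor_mask(key, payload):
    """Reference masking: XOR payload byte i with key byte i % 4."""
    return bytes(byte ^ key[i % 4] for i, byte in enumerate(payload))


key = b"\x01\x02\x03\x04"  # illustrative key only; real keys are random
payload = b"Hello"

masked = xor_mask(key, payload)
print(masked)  # prints b'Igohn'

# Applying the same key a second time restores the original payload.
assert xor_mask(key, masked) == payload
```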

Solution 1

The first version of my XOR mask function looked like this:

from itertools import cycle
def mask1(mask, data):
    return bytes(a ^ b for a, b in zip(cycle(mask), data))

That's the kind of unfussy and elegant code you want to write for a job interview. If that's the solution that first came to mind, then you know your Python. Alas, this code is as slow as it is beautiful.

It is slow because every byte runs through the Python interpreter. For a chat server, it won't matter. But if you are sending megabytes of data over a websocket, chances are you won't be able to mask websocket packets as fast as the network can accept them.

Solution 2

One option to speed this up is to use wsaccel to do the masking, which is much faster due to a tight inner loop written in C. Here's the code for that:

from wsaccel.xormask import XorMaskerSimple
def mask2(mask, data):
    return XorMaskerSimple(mask).process(data)

If you can install wsaccel on your system, then use that. There is no reason not to. Unfortunately, the project I'm working on needs to be able to run a websocket client on systems where it isn't always easy to install C extensions. But without some serious optimization, the pure-python solution is going to be painfully slow. Is there a faster solution than mask1 in pure-Python?

Solution 3

Turns out there is. The following is a rather clever solution from the aiohttp project which runs significantly faster:

import sys
native_byteorder = sys.byteorder
def mask3(mask, data):
    datalen = len(data)
    data = int.from_bytes(data, native_byteorder)
    mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4], native_byteorder)
    return (data ^ mask).to_bytes(datalen, native_byteorder)

As awkward as it looks, this solution blows mask1 out of the water in terms of speed. Essentially it treats the websocket payload and mask as very large numbers, does the XOR as a single operation, then turns the result back into bytes. It's fast because the work happens in a single operation, rather than per-byte.
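To see why the big-integer trick gives the same answer as the per-byte loop, here is a minimal Python 3 sketch on a six-byte payload. The key and payload values are made up for illustration. XOR has no carries between bits, so each byte of the result depends only on the matching bytes of the two inputs.

```python
import sys

key = b"\x01\x02\x03\x04"  # illustrative key
payload = b"Hello!"         # 6 bytes: one full key repeat plus 2 more

# Repeat the key out to the payload length, the same way mask3 does.
repeated = key * (len(payload) // 4) + key[: len(payload) % 4]

# Any byte order works as long as all three conversions agree on it.
order = sys.byteorder
as_int = int.from_bytes(payload, order) ^ int.from_bytes(repeated, order)
masked = as_int.to_bytes(len(payload), order)

# The per-byte XOR produces exactly the same bytes.
per_byte = bytes(p ^ k for p, k in zip(payload, repeated))
assert masked == per_byte
```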

Solution 4

I might have been happy with that, if it worked with Python2.7. Unfortunately, int.from_bytes was introduced in Python3.2 and my code is intended to run on Python2.7 and 3.X. I needed another trick that runs as fast, but works on both major versions of Python.

This is what I came up with:

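# Row b maps every byte value a to a ^ b. Building each row via bytearray
# gives a 256-byte table on both Python 2.7 and 3.X.
_XOR_TABLE = [bytes(bytearray(a ^ b for a in range(256))) for b in range(256)]
def mask4(mask, data):
    data_bytes = bytearray(data)
    # bytearray(mask) yields integers on both Python versions.
    a, b, c, d = (_XOR_TABLE[n] for n in bytearray(mask))
    # Translate every fourth byte with the table row for its mask byte.
    data_bytes[::4] = data_bytes[::4].translate(a)
    data_bytes[1::4] = data_bytes[1::4].translate(b)
    data_bytes[2::4] = data_bytes[2::4].translate(c)
    data_bytes[3::4] = data_bytes[3::4].translate(d)
    return bytes(data_bytes)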

A quick profile indicated that mask4 was faster than mask3 and approaching the same speed as the C version. I really didn't expect such a big win.

The trick here is to pre-calculate the results of the XOR operation in a table. This table contains every combination of two bytes XORed together. So _XOR_TABLE[a][b] is equivalent to a ^ b for every possible value of bytes a and b.
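That property is easy to check exhaustively. The following is a minimal sketch for Python 3, where indexing a bytes object yields an integer; the lowercase name xor_table is used only for this illustration.

```python
# Row b of the table holds a ^ b for every byte value a.
xor_table = [bytes(a ^ b for a in range(256)) for b in range(256)]

# Every entry matches the XOR it stands in for.
assert all(xor_table[a][b] == a ^ b for a in range(256) for b in range(256))

# 256 rows of 256 bytes each, before any per-object overhead.
table_bytes = sum(len(row) for row in xor_table)
print(table_bytes)  # prints 65536
```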

We can now use bytearray.translate to replace payload bytes with a corresponding byte from the table. You can be forgiven if you've never seen (or have forgotten) the translate method; it's always been on the Python2 string type and is inherited by bytearray objects.

The reason translate is called four times is that it can only do a replace for a single row in the table, and we need to pick a row in the table for each byte in the mask. Slices are used to extract and replace the appropriate bytes for each byte of mask.
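Here is a minimal sketch of one of those four steps in isolation, for Python 3. The mask byte 0x20 and the payload are made up for illustration. The extended slice picks out every fourth byte, translate maps those bytes through a single table row, and the slice assignment writes them back in place.

```python
# One table row: maps each byte value to that value XORed with 0x20
# (0x20 is an arbitrary mask byte chosen for this example).
row = bytes(value ^ 0x20 for value in range(256))

data = bytearray(b"abcdefgh")

# Every fourth byte starting at index 0.
assert bytes(data[::4]) == b"ae"

# Translate just those bytes and write them back to the same positions.
data[::4] = data[::4].translate(row)

print(bytes(data))  # prints b'AbcdEfgh'
```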

Speed Comparison

I did some profiling to compare the various versions of the masking function. The pretty pure-python version was so slow that it was impossible to make a meaningful comparison, so I've removed it from the following chart:

This graph was generated on my MacBook, which is way more powerful than the devices the code was written for.

The X axis is the number of bytes in the Websocket payload.

The Y axis is the time taken (milliseconds) to mask the payload.

The translate version is only slightly slower than the C version, which is pleasing. Although it is worth pointing out that up until around 2500 bytes, the from_bytes version is slightly faster than translate.
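If you want to try a comparison like this on your own hardware, here is a minimal timing sketch for Python 3 using timeit. It is not the script that produced the graph: time_mask is a hypothetical helper, the payload sizes and repeat count are arbitrary, and the masking functions are Python 3 copies so the block runs on its own. The C version is left out so that nothing needs installing.

```python
import sys
import timeit
from itertools import cycle


def mask1(mask, data):
    return bytes(a ^ b for a, b in zip(cycle(mask), data))


def mask3(mask, data):
    datalen = len(data)
    order = sys.byteorder
    big_data = int.from_bytes(data, order)
    big_mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4], order)
    return (big_data ^ big_mask).to_bytes(datalen, order)


_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]


def mask4(mask, data):
    data_bytes = bytearray(data)
    a, b, c, d = (_XOR_TABLE[n] for n in mask)
    data_bytes[::4] = data_bytes[::4].translate(a)
    data_bytes[1::4] = data_bytes[1::4].translate(b)
    data_bytes[2::4] = data_bytes[2::4].translate(c)
    data_bytes[3::4] = data_bytes[3::4].translate(d)
    return bytes(data_bytes)


def time_mask(func, size, repeats=20):
    """Hypothetical helper: mean milliseconds per call for `size` bytes."""
    mask = b"\x12\x34\x56\x78"
    data = bytes(range(256)) * (size // 256)
    seconds = timeit.timeit(lambda: func(mask, data), number=repeats)
    return seconds / repeats * 1000.0


# Arbitrary payload sizes; the numbers depend entirely on the machine.
for size in (1024, 4096, 65536):
    timings = ", ".join(
        "%s %.4f ms" % (func.__name__, time_mask(func, size))
        for func in (mask1, mask3, mask4)
    )
    print("%6d bytes: %s" % (size, timings))
```

Run it a few times before trusting the numbers; short payloads in particular are noisy.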

Conclusion?

The takeaway from this is that explicit inner loops in Python can be slow. If there is a built-in method that operates on a number of elements in a single call, then it will likely be faster. Even if that means doing additional work to prepare.

The caveat is that this level of micro-optimization is rarely worth the effort. The final version would probably earn you a frowny face in a code review (if it wasn't accompanied by a comment explaining the reasoning behind it). In this instance however, the ugliness is probably justifiable.

The code for the masking functions (which also generates the above graph) is available here.

Update

I've added the masking code from aiohttp to the graph. It is significantly faster than wsaccel because it does the XOR operation 64 bits at a time, compared to wsaccel's 8 bits at a time.

aiohttp has masking code written in Cython, which masks 64 bits at a time for an impressive speed-up.

Ajasja Ljubetič

Shouldn't it be

return bytes(data_bytes)

in mask4?

Will McGugan

It should. Thanks for the correction.

Joao Matos

Did you try using numpy bitwise_xor or numpy.ndarray.xor?

https://docs.scipy.org/doc/numpy/reference/generated/numpy.bitwise_xor.html https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.xor.html

What's the byte size of _XOR_TABLE?

Will McGugan

Nope. I needed a pure-python solution. I suspect numpy would perform well though. _XOR_TABLE is 64K data, plus python overhead.

13steinj

Couldn't you just have used the struct module to do what aiohttp did using int.from_bytes in Py2?

Will McGugan

May be possible, but as far as I know struct can't create arbitrarily long integers. So you would have to use a loop, and you would lose any performance gain...