DEV Community

Yuan Gao

Advent of Code 2020: Day 14 using bitwise logic in Python

Today's challenge involves a lot of bit twiddling! As a former embedded engineer, I find this stuff brings back memories.

The Challenge Part 1

Link to challenge on Advent of Code 2020 website

The challenge involves applying a tri-state mask to a binary value. The example given is:

mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X

Any 1 or 0 in the mask sets or clears the corresponding bit of the value it is applied to, while an X leaves that bit untouched.

Bitwise OR and AND

One thing you learn very quickly when working with embedded systems is how to do bit twiddling, because a lot of microcontroller registers map individual bits or ranges of bits to specific functions, whether that's doing IO or setting up peripherals.

A bitwise OR is often used for a SET operation: any 1 in a value, when OR'd with another value, sets that bit.

A bitwise AND is often used for a CLEAR operation: any 0 in a value, when AND'd with another value, clears that bit.
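As a quick illustration, here is a minimal sketch with arbitrarily chosen bit positions. OR with a single-bit value sets that bit, and AND with the complement of a single-bit value clears it. Python's `~` produces a negative int, but `&` still behaves as expected because ints act like infinitely sign-extended two's complement:

```python
value = 0b1011          # 11 in decimal

# OR with a one-bit value sets that bit (bit 6 here)
set_bit6 = value | (1 << 6)

# AND with the complement of a one-bit value clears that bit (bit 1 here)
cleared_bit1 = value & ~(1 << 1)

print(f"{set_bit6:07b}")      # 1001011
print(f"{cleared_bit1:07b}")  # 0001001
```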

So the example above, XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X, can be decomposed into two operations. The first is an OR with 000000000000000000000000000001000000 (the mask with every X replaced by 0), which sets the one bit we want to set. The second is an AND with 111111111111111111111111111111111101 (the mask with every X replaced by 1), which clears the one bit we want to clear.

So, in Python, applying the mask is simply:

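mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X"
value = 11  # an example value to apply the mask to
set_val = int(mask.replace("X", "0"), 2)  # 1 only where the mask forces a 1
clr_val = int(mask.replace("X", "1"), 2)  # 0 only where the mask forces a 0
result = (value | set_val) & clr_val      # 73 for this example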

This leverages Python's int(), which converts a binary string to an integer when given base 2.
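The two steps can be made concrete by wrapping them in a small function and running it over the sample values from the puzzle's part 1 example (11, 101 and 0). This is a sketch, and `apply_mask` is just an illustrative name. Note that 101 already has bit 6 set and bit 1 clear, so the mask leaves it unchanged:

```python
def apply_mask(value, mask):
    """Apply a part 1 mask: 1s are set, 0s are cleared, Xs are left alone."""
    set_val = int(mask.replace("X", "0"), 2)  # only the forced 1 bits
    clr_val = int(mask.replace("X", "1"), 2)  # all ones except the forced 0 bits
    return (value | set_val) & clr_val

mask = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X"
for value in (11, 101, 0):
    print(value, "->", apply_mask(value, mask))
# 11 -> 73
# 101 -> 101
# 0 -> 64
```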

The rest of our code deals with reading and parsing the input. As before, we could use PEG grammars, but this time I'll keep the code compact and use regex matching instead.

First, we define regexes (regices?) for each of the two kinds of line:

import re
# Raw strings so that \[ and \d reach the regex engine unchanged
mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")
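A quick check of what these patterns capture, as a minimal sketch using sample lines in the puzzle's format. Lines read with readlines() keep their trailing newline, and `$` still matches just before a final newline, so no stripping is needed. Raw strings avoid Python's invalid-escape warnings for `\[` and `\d`:

```python
import re

# Raw strings keep backslashes such as \d and \[ intact for the regex engine
mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")

print(mask_re.match("mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X").groups())
# ('XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X',)
print(mem_re.match("mem[8] = 11\n").groups())
# ('8', '11')
print(mem_re.match("mask = 0X1"))
# None
```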

Next, we loop through all the data. For a mask line, we decode it and update our set_val/clr_val variables. Otherwise it's a memory line, so we decode it, apply the mask to the value, and store the result at that address. A dictionary maps memory addresses to their values:

mem = {}
for entry in data:
    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X","0"), 2)
        clr_val = int(mask.replace("X", "1"), 2)
        continue

    addr, value = mem_re.match(entry).groups()
    mem[addr] = (int(value) | set_val) & clr_val

The challenge asks for the sum of all values left in memory, which is simply:

print("sum", sum(mem.values()))

The full code:

import re

# Raw strings so that \[ and \d reach the regex engine unchanged
mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")

with open("input.txt") as f:
    data = f.readlines()

mem = {}  # address (as a string) -> masked value
for entry in data:
    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X", "0"), 2)  # bits the mask forces to 1
        clr_val = int(mask.replace("X", "1"), 2)  # bits the mask forces to 0
        continue

    addr, value = mem_re.match(entry).groups()
    mem[addr] = (int(value) | set_val) & clr_val
print("sum", sum(mem.values()))
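To test the same logic without an input file, the loop can be wrapped in a function and fed the sample program from the puzzle statement, whose expected answer is 165. This is a sketch: `run_part1` and `sample` are illustrative names, and like the code above it assumes the first line is a mask:

```python
import re

mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")

def run_part1(lines):
    """Same logic as the full code above, wrapped in a function for testing."""
    mem = {}
    for entry in lines:
        mask_line = mask_re.match(entry)
        if mask_line:
            mask = mask_line.group(1)
            set_val = int(mask.replace("X", "0"), 2)
            clr_val = int(mask.replace("X", "1"), 2)
            continue
        addr, value = mem_re.match(entry).groups()
        mem[addr] = (int(value) | set_val) & clr_val
    return sum(mem.values())

sample = [
    "mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X",
    "mem[8] = 11",
    "mem[7] = 101",
    "mem[8] = 0",
]
print(run_part1(sample))  # 165
```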

The Challenge Part 2

The second part of the challenge switches things up. Instead of applying the mask to the value, it applies it to the address the value is written to, with a twist. For a mask like XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X, each X is a "floating" bit, so the value is written not to one address but to many: one for every combination of 1s and 0s in the X positions!

Calculating the new address

The first step is to calculate the new address format. Within our loop, we use regex to extract the mask just like before. This time we don't need clr_val, because in part 2 a 0 in the mask leaves the corresponding address bit unchanged rather than clearing it. Only the 1s (which set bits) and the X'es (which float) modify the address.

    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X", "0"), 2)
        continue

Then we take the address and combine it with the mask to form the new address, complete with all the X'es:

    addr, value = mem_re.match(entry).groups()
    addr = int(addr) | set_val
    new_addr = "".join([m if m == "X" else a for m, a in zip(mask, f"{addr:036b}")])

Here, addr has all the 1 bits of the mask OR'd onto it. It is then converted back into a string using Python's format-specification mini-language: f"{addr:036b}" produces a zero-padded, 36-character binary string in the same format as the mask. Finally, the list comprehension puts an X wherever the mask has an X, and otherwise takes the bit from the address.
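To see the part 2 rules end to end, here is a sketch that packages these two lines into a function. `floating_address` is an illustrative name. It is run on the part 2 example from the puzzle statement, address 42 with mask 000000000000000000000000000000X1001X:

```python
def floating_address(addr, mask):
    """Part 2 rule: 0 keeps the address bit, 1 forces it to 1, X floats."""
    set_val = int(mask.replace("X", "0"), 2)  # forced 1 bits only
    addr = addr | set_val
    # Copy the X's from the mask; everything else comes from the 36-bit address
    return "".join(m if m == "X" else a for m, a in zip(mask, f"{addr:036b}"))

mask = "0" * 30 + "X1001X"  # 36 characters, as in the puzzle's example
print(floating_address(42, mask)[-6:])  # X1101X
```

The 0 in the third position from the left of X1001X lets the address's own 1 bit through. A part 1 style AND would have cleared it.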

This leaves us with a new_addr that we must now deal with. This address matches multiple real addresses, and so we have a couple of options:

The naive solution

First, let's brute force it - just straight-up generate every possible address location from this new_addr and write it in memory, and sum the total.

To do this, we'll use NumPy to find the positions of all the X'es in the string. As an example, let's use 100010X1X0X011X1101XX00011X01010X10X as new_addr:

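import numpy as np

new_addr = "100010X1X0X011X1101XX00011X01010X10X"
np_addr = np.array(list(new_addr))  # one array element per character
xloc = np.argwhere(np_addr == "X")   # indices of the X characters, as a column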

Output (new_addr= "100010X1X0X011X1101XX00011X01010X10X")

array([[ 6],
       [ 8],
       [10],
       [14],
       [19],
       [20],
       [26],
       [32],
       [35]])

These are the indices of the X characters. There are 9 of them, so there are 2^9 = 512 possible addresses. To build address number i (from 0 to 511), we convert i into a 9-bit binary string and use one bit for each X. Take i = 123 for example:

f"{i:b}".zfill(xloc.size)

Output (i = 123)

'001111011'
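The zfill call can also be written as a zero-padded width inside the format spec. Either way, counting i from 0 to 2**n - 1 produces every n-bit pattern exactly once, which is why a single loop covers every combination of the X'es. A minimal sketch, assuming n = 9 as in the example address:

```python
n = 9  # number of X's in the example address

patterns = [f"{i:0{n}b}" for i in range(2**n)]

print(patterns[123])        # 001111011, same as f"{123:b}".zfill(9)
print(len(patterns))        # 512
print(len(set(patterns)))   # 512: every pattern is distinct
```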

This converts 123 into its binary representation and zero-pads it to 9 characters. NumPy can then write these 9 characters into the positions found earlier. np.vstack splits the string into a 9 by 1 column of characters, which matches the shape of xloc:

np_addr[xloc] = np.vstack(f"{i:b}".zfill(xloc.size))
"".join(np_addr)

Output

'100010010010111110111000110010101101'

Neat, eh? So the code for generating all of the possible address combinations is the following, written as a Python generator:

def generate_addresses(addr):
    np_addr = np.array(list(addr))
    xloc = np.argwhere(np_addr == "X")
    for i in range(2**xloc.size):
        np_addr[xloc] = np.vstack(f"{i:b}".zfill(xloc.size))
        yield "".join(np_addr)
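If NumPy isn't available, the same generator can be sketched in pure Python with itertools.product. This is an alternative, not the approach used in this post. product("01", repeat=k) yields the k-bit combinations in the same order as counting i from 0 to 2**k - 1, so the addresses come out in the same order as the NumPy version. The test address is the floating address from the puzzle's part 2 example:

```python
from itertools import product

def generate_addresses(addr):
    """Yield every concrete address matched by a floating address string.

    Pure-Python alternative to the NumPy version: the X positions are filled
    with every combination of 0/1 produced by itertools.product.
    """
    xloc = [i for i, c in enumerate(addr) if c == "X"]
    chars = list(addr)
    for bits in product("01", repeat=len(xloc)):
        for pos, bit in zip(xloc, bits):
            chars[pos] = bit
        yield "".join(chars)

floating = "0" * 30 + "X1101X"  # floating address from the part 2 example
print(sorted(int(a, 2) for a in generate_addresses(floating)))  # [26, 27, 58, 59]
```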

Combining everything together, the full code for this part of the challenge is:

import re
import numpy as np

mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")

with open("input.txt") as f:
    data = f.readlines()

def generate_addresses(addr):
    """Yield every concrete address matched by a floating address string."""
    np_addr = np.array(list(addr))
    xloc = np.argwhere(np_addr == "X")  # positions of the floating bits
    for i in range(2**xloc.size):
        # Write the bits of i into the X positions, one bit per X
        np_addr[xloc] = np.vstack(f"{i:b}".zfill(xloc.size))
        yield "".join(np_addr)

mem = {}
for entry in data:
    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X", "0"), 2)  # 1s in the mask force address bits to 1
        continue

    addr, value = mem_re.match(entry).groups()
    addr = int(addr) | set_val
    # Copy the X's from the mask into the 36-bit binary address
    new_addr = "".join([m if m == "X" else a for m, a in zip(mask, f"{addr:036b}")])

    for addr_candidate in generate_addresses(new_addr):
        mem[addr_candidate] = int(value)

print("sum", sum(mem.values()))

As it turns out, my input only generates 72923 memory locations, which is a relatively small number for a computer to deal with. But can we do better?

A more optimized solution

Yes. The puzzle only asks for the sum of all the values in memory, not the value at each address. So if our mask is XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X, we don't need to resolve every single address and write to it. We just need to count how many addresses the write covers (2^34, one per combination of the 34 X'es), multiply the value by that count, and add the result to the sum.

The only trick is handling addresses that get overwritten by later writes. For each entry, we calculate how many addresses it matches and work out how many of those are overwritten by later writes. We then multiply the number of addresses left untouched by the value written into them. The sum over all entries should give the same total as the brute-force approach.
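To make the bookkeeping concrete, here is the arithmetic for the small part 2 sample from the puzzle statement. It writes 100 to a floating address covering 26, 27, 58 and 59, then writes 1 to one covering 16 to 19 and 24 to 27. The matched and overwritten counts below are worked out by hand, so this only illustrates the formula:

```python
# Entry 1: value 100, floating address ...X1101X -> 4 addresses (26, 27, 58, 59)
# Entry 2: value 1,   floating address ...01X0XX -> 8 addresses (16-19, 24-27)
entries = [
    (100, 4, 2),  # (value, addresses matched, overwritten later): 26 and 27
    (1, 8, 0),    # nothing is written after the last entry
]

total = sum(value * (matched - overwritten) for value, matched, overwritten in entries)
print(total)  # 208
```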

Calculating address overlap

Two wildcard addresses overlap (that is, they match at least one common real address) if, at every bit position, all of the following are true:

  • every 1 bit in the first address matches either a 1 or an X in the second address
  • every 0 bit in the first address matches either a 0 or an X in the second address
  • every X bit in the first address can match anything in the second address

Actually, the reverse is shorter. We know the addresses DON'T overlap if any of the following is true at some bit position:

  • a 1 bit in the first address matches a 0 in the second address
  • a 0 bit in the first address matches a 1 in the second address
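In keeping with the bit-twiddling theme, the same "no 1 against 0" test can also be done with integers. Split each wildcard address into a bitmask of its forced 1s and a bitmask of its forced 0s. Two addresses are then disjoint exactly when one address's 1s collide with the other's 0s. This is an alternative sketch, not the string comparison the loop in this post uses, and `to_masks` and `overlaps` are illustrative names:

```python
def to_masks(addr):
    """Return (ones, zeros) bitmasks for the fixed bits of a wildcard address."""
    ones = int(addr.replace("X", "0"), 2)
    zeros = int("".join("1" if c == "0" else "0" for c in addr), 2)
    return ones, zeros

def overlaps(a, b):
    """True if wildcard addresses a and b match at least one common address."""
    a_ones, a_zeros = to_masks(a)
    b_ones, b_zeros = to_masks(b)
    # Disjoint iff some position is 1 in one address and 0 in the other
    return (a_ones & b_zeros) == 0 and (a_zeros & b_ones) == 0

print(overlaps("X1101X", "01X0XX"))  # True: both contain 011010 and 011011
print(overlaps("X1101X", "X0XXXX"))  # False: the second character is 1 vs 0
```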

So we can update our loop from before to check each new address for overlap with the existing memory entries, and keep track of which earlier addresses get overwritten:


mem = {}
for entry in data:
    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X", "0"), 2)
        continue

    addr, value = mem_re.match(entry).groups()
    addr = int(addr) | set_val
    new_addr = "".join([m if m == "X" else a for m, a in zip(mask, f"{addr:036b}")])

    for old_addr, value_sub in mem.items():
        if not any(pair in (("1", "0"),("0", "1")) for pair in zip(old_addr, new_addr)):
            value_sub.append(new_addr)

    mem[new_addr] = [int(value)]

The end result here is a mem dictionary that doesn't store real addresses. Instead, it is keyed by wildcard address. Each entry holds a list containing the value written to that address, followed by every later wildcard address detected to overlap with it.

Then we can adapt our generator from before. Given a wildcard address and the list of later addresses that overlap it, it generates every part of the address that gets overwritten:

def generate_subaddresses(addr, others):
    np_addr = np.array(list(addr))
    xloc = np.argwhere(np_addr == "X")

    for other in others:
        np_other = np.array(list(other))
        mini_addr = np_other[xloc].flatten()

        mini_xloc = np.argwhere(mini_addr == "X")
        for i in range(2**mini_xloc.size):
            mini_addr[mini_xloc] = np.vstack(f"{i:b}".zfill(len(mini_xloc)))
            yield "".join(mini_addr)

For a given address, we first find the positions of all its X'es. We then take the characters at those same positions from each overlapping address and generate every possible "mini address" from them (just one if none of them is an X). The number of unique mini addresses generated is the number of the original address's locations that were overwritten.
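Another way to see what is being counted: when two wildcard addresses overlap, their intersection is itself a wildcard address. It has an X only where both have an X, and a fixed bit wherever either one is fixed. A single overlap therefore covers 2 to the power of the shared X count. The set in generate_subaddresses is still needed when several later writes overlap each other, so that shared addresses are not counted twice. A minimal sketch, with `intersect` as an illustrative name:

```python
def intersect(a, b):
    """Intersection of two overlapping wildcard addresses."""
    out = []
    for x, y in zip(a, b):
        if x == "X":
            out.append(y)   # a is free here, so b's character decides
        elif y == "X" or x == y:
            out.append(x)   # a is fixed and b agrees or is free
        else:
            raise ValueError("addresses do not overlap")
    return "".join(out)

common = intersect("X1101X", "01X0XX")
print(common, 2 ** common.count("X"))  # 01101X 2
```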

Now we have the wildcard addresses, the overlapping addresses for each one, and a function that can enumerate the overlapping combinations. All that remains is to calculate our total:

total = 0
for addr, (value, *others) in mem.items():
    overwritten = len(set(generate_subaddresses(addr, others)))
    total += (2**addr.count("X") - overwritten)*value
total

The full code to this part of the challenge:

import re
import numpy as np

mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")

with open("input.txt") as f:
    data = f.readlines()

def generate_subaddresses(addr, others):
    """Yield the overwritten parts of addr, as strings over addr's X positions."""
    np_addr = np.array(list(addr))
    xloc = np.argwhere(np_addr == "X")

    for other in others:
        np_other = np.array(list(other))
        mini_addr = np_other[xloc].flatten()  # other's bits at addr's X positions

        mini_xloc = np.argwhere(mini_addr == "X")
        for i in range(2**mini_xloc.size):
            mini_addr[mini_xloc] = np.vstack(f"{i:b}".zfill(len(mini_xloc)))
            yield "".join(mini_addr)

mem = {}  # wildcard address -> [value, later overlapping wildcard addresses...]
for entry in data:
    mask_line = mask_re.match(entry)
    if mask_line:
        mask = mask_line.groups()[0]
        set_val = int(mask.replace("X", "0"), 2)
        continue

    addr, value = mem_re.match(entry).groups()
    addr = int(addr) | set_val
    new_addr = "".join([m if m == "X" else a for m, a in zip(mask, f"{addr:036b}")])

    # Record new_addr against every earlier address it overlaps (no 1-vs-0 clash)
    for old_addr, value_sub in mem.items():
        if not any(pair in (("1", "0"), ("0", "1")) for pair in zip(old_addr, new_addr)):
            value_sub.append(new_addr)

    mem[new_addr] = [int(value)]

total = 0
for addr, (value, *others) in mem.items():
    # Addresses matched minus those overwritten later, times the value written
    overwritten = len(set(generate_subaddresses(addr, others)))
    total += (2**addr.count("X") - overwritten)*value
print("total", total)

This method is substantially quicker than the brute-force approach, because it never generates every possible memory address, only the overlapping ones. The total is the sum of each value multiplied by the number of addresses it was written to and not overwritten by later operations.
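To sanity-check that the optimized count agrees with brute force, here is a small NumPy-free sketch of both approaches. Both run on the part 2 sample program from the puzzle statement, whose expected answer is 208. It follows the same idea as the code above: each wildcard write is checked against later overlapping writes, and the union of their intersections is counted. The function names are illustrative, and the code assumes the input starts with a mask line:

```python
import re
from itertools import product

mask_re = re.compile(r"^mask = ([10X]+)$")
mem_re = re.compile(r"^mem\[(\d+)\] = (\d+)$")


def floating_writes(lines):
    """Yield (wildcard_address, value) for each mem line, using the part 2 rules."""
    for entry in lines:
        mask_line = mask_re.match(entry)
        if mask_line:
            mask = mask_line.group(1)
            set_val = int(mask.replace("X", "0"), 2)
            continue
        addr, value = mem_re.match(entry).groups()
        bits = f"{(int(addr) | set_val):036b}"
        yield "".join(m if m == "X" else a for m, a in zip(mask, bits)), int(value)


def expand(addr, positions):
    """Yield every 0/1 string for addr's characters at the given positions."""
    chars = [addr[p] for p in positions]
    xs = [i for i, c in enumerate(chars) if c == "X"]
    for combo in product("01", repeat=len(xs)):
        for i, bit in zip(xs, combo):
            chars[i] = bit
        yield "".join(chars)


def brute_force(lines):
    """Write the value to every concrete address, then sum the memory."""
    mem = {}
    for addr, value in floating_writes(lines):
        for concrete in expand(addr, range(len(addr))):
            mem[concrete] = value
    return sum(mem.values())


def optimized(lines):
    """Count each write's addresses that no later write overwrites."""
    writes = list(floating_writes(lines))
    total = 0
    for i, (addr, value) in enumerate(writes):
        xloc = [p for p, c in enumerate(addr) if c == "X"]
        overwritten = set()
        for later, _ in writes[i + 1:]:
            if not any(pair in (("1", "0"), ("0", "1")) for pair in zip(addr, later)):
                # Inside the overlap, only addr's X positions can vary
                overwritten.update(expand(later, xloc))
        total += (2 ** len(xloc) - len(overwritten)) * value
    return total


sample = [
    "mask = " + "0" * 30 + "X1001X",
    "mem[42] = 100",
    "mask = " + "0" * 32 + "X0XX",
    "mem[26] = 1",
]
print(brute_force(sample), optimized(sample))  # 208 208
```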

Onward!
