Secure Dot Product

[crypto]

Our intern thought it was a great idea to vibe code a secure dot product server using our AES key. Having taken a class in linear algebra, they’re confident the server can’t ever leak our key, but I’m not so sure…

file

picoCTF{n0t_so_s3cure_.x_w1th_sh@512_[***]}

Summary

We have access to a server that will compute dot products against its secret AES key, but only for “trusted” vectors.

The server prints five trusted vectors and their salted SHA-512 hashes, then lets us submit vectors as long as we also provide the matching salted hash.

However, the server has two bugs that together let us recover the AES key:

Understanding the Encryption and remote.py

We are given remote.py, which reads from a flag.txt on the remote server. A random 32-byte key and 16-byte initialization vector (IV) is generated. The flag is encrypted with AES.MODE_CBC using the key and IV.

What is AES.MODE_CBC?

CBC (Cipher Block Chaining) splits the padded plaintext into 16-byte blocks (P1, P2, P3…) and encrypts each block at a time. The ciphertext of the previous block is XOR’ed with the next plaintext before encryption with the key. So, given plaintext chunks P1, P2, P3…, ciphertext chunks C1, C2, C3…

C1 = encrypt(P1 XOR IV, key)
C2 = encrypt(P2 XOR C1, key)
C3 = encrypt(P3 XOR C2, key)

To decrypt, we need the key and IV, both of which the server gives us upfront. So all we need to do is recover the key.

The dot product oracle

The 32-byte AES key is stored as a list of integers:

self.key_vector = [byte for byte in key]

The server generates five random trusted vectors and computes their dot products with key_vector. It prints each trusted vector alongside its salted hash:

def hash_vector(self, vector):
    vector_encoding = vector[1:-1].encode('latin-1')
    return hashlib.sha512(self.salt + vector_encoding).digest().hex()

So the hash covers salt || vector[1:-1] — the raw string content of the vector, without the outer brackets. The salt is 256 bytes long and secret.

When we submit a vector, the server verifies our hash matches before computing the dot product:

def compute_dot(self, vector):
    parsed = self.parse_vector(vector)
    return sum(a * b for a, b in zip(parsed, self.key_vector))

Every accepted query gives us one linear equation: dot(our_vector, key). If we could freely choose vectors, we could recover the key in 32 queries. The hash check is supposed to prevent that.

Parsing the input

def parse_vector(self, vector):
    sanitized = "".join(c if c in '0123456789,[]' else '' for c in vector)
    ...

Only digits, commas, and square brackets are not sanitized. Everything else is stripped, including minus signs. So if the server trusted the vector [-5, 10, -7], submitting it back would parse as [5, 10, 7].

We can also see how the server reads input:

vector_input = input("Enter your vector: ")
vector_input = vector_input.encode().decode('unicode_escape')

It interprets escape sequences before doing anything else. So if we send \x80\x00, the server sees the raw bytes 0x80 0x00. This is important for smuggling SHA-512 padding bytes through.

Vulnerability Analysis

1. SHA-512 length extension

The server authenticates vectors using SHA512(salt || message). This is called a secret-prefix MAC, and it’s insecure with Merkle-Damgård hashes like SHA-512.

SHA-512 processes data in 128-byte blocks. After hashing a message, the internal state is exactly the digest. If you know a digest and the length of the original input, you can continue hashing from that state without needing the salt.

Specifically, given SHA512(salt || message) and len(salt || message), we can compute:

SHA512(salt || message || glue_padding || extra)

where glue_padding is the SHA-512 padding that was appended to salt || message internally. The server will accept this hash because it’s valid as long as we extend the hash without knowing the salt.

2. Input sanitation issue

The server hashes the raw vector string but evaluates a sanitized version.

We can inject bytes like 0x80, 0x00. Almost all of these are stripped by the sanitizer because they’re not in 0123456789,[].

So we can forge a hash for message || glue_padding || extra, and the server will parse it as roughly sanitize(message) || sanitize(extra).

Recovering the Key

Step 1: Forge a baseline query for each trusted vector

For each trusted (vector, digest) pair, we length-extend the hash to cover message || glue_padding with no extra suffix. We send this to the server and get back:

baseline = P[0]*k[0] + P[1]*k[1] + ... + P[m-1]*k[m-1]

where P is the parsed prefix (what survives sanitization) and k is the AES key. We compute P locally using the same sanitizer, so we know it exactly.

Among the five forged baselines, we pick the one whose parsed prefix P is shortest. Call that length m.

The shorter the prefix, the fewer key bytes are locked into the baseline and the more we can control with appended tails.

If m > 5 we can’t solve the system with only 5 equations, so we reconnect and try again with a new random instance.

Step 2: Recover key bytes m through 31

Using the pivot vector (shortest prefix), we append tails like [0, 0, ..., 1] where the 1 lands at position j. The server returns:

result = baseline + 1 * k[j]

So:

k[j] = result - baseline

We do this for every j from m to 31, recovering the later key bytes one at a time.

Step 3: Solve the first m key bytes

Each of the five baselines gives us an equation:

R_i = P_i[0]*k[0] + ... + P_i[m-1]*k[m-1] + (known suffix terms)

We subtract the known suffix terms (using the key bytes we already recovered) to isolate the first m unknowns:

R_i' = P_i[0]*k[0] + ... + P_i[m-1]*k[m-1]

This is a linear system with m unknowns and up to 5 equations. We solve it exactly using sympy.Matrix.LUsolve. Each solution value must be an integer in [0, 255] to be a valid key byte.

Once we have all 32 key bytes, standard AES-CBC decryption gives us the flag.

Solve script

#!/usr/bin/env python3
import ast
from itertools import combinations

from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from pwn import context, process, remote
from sympy import Matrix

HOST = "***"
PORT = ***
KEY_SIZE = 32
SALT_SIZE = 256
ALLOWED_CHARS = "0123456789,[]"
MASK64 = (1 << 64) - 1

K = [
    0x428A2F98D728AE22,
    0x7137449123EF65CD,
    0xB5C0FBCFEC4D3B2F,
    0xE9B5DBA58189DBBC,
    0x3956C25BF348B538,
    0x59F111F1B605D019,
    0x923F82A4AF194F9B,
    0xAB1C5ED5DA6D8118,
    0xD807AA98A3030242,
    0x12835B0145706FBE,
    0x243185BE4EE4B28C,
    0x550C7DC3D5FFB4E2,
    0x72BE5D74F27B896F,
    0x80DEB1FE3B1696B1,
    0x9BDC06A725C71235,
    0xC19BF174CF692694,
    0xE49B69C19EF14AD2,
    0xEFBE4786384F25E3,
    0x0FC19DC68B8CD5B5,
    0x240CA1CC77AC9C65,
    0x2DE92C6F592B0275,
    0x4A7484AA6EA6E483,
    0x5CB0A9DCBD41FBD4,
    0x76F988DA831153B5,
    0x983E5152EE66DFAB,
    0xA831C66D2DB43210,
    0xB00327C898FB213F,
    0xBF597FC7BEEF0EE4,
    0xC6E00BF33DA88FC2,
    0xD5A79147930AA725,
    0x06CA6351E003826F,
    0x142929670A0E6E70,
    0x27B70A8546D22FFC,
    0x2E1B21385C26C926,
    0x4D2C6DFC5AC42AED,
    0x53380D139D95B3DF,
    0x650A73548BAF63DE,
    0x766A0ABB3C77B2A8,
    0x81C2C92E47EDAEE6,
    0x92722C851482353B,
    0xA2BFE8A14CF10364,
    0xA81A664BBC423001,
    0xC24B8B70D0F89791,
    0xC76C51A30654BE30,
    0xD192E819D6EF5218,
    0xD69906245565A910,
    0xF40E35855771202A,
    0x106AA07032BBD1B8,
    0x19A4C116B8D2D0C8,
    0x1E376C085141AB53,
    0x2748774CDF8EEB99,
    0x34B0BCB5E19B48A8,
    0x391C0CB3C5C95A63,
    0x4ED8AA4AE3418ACB,
    0x5B9CCA4F7763E373,
    0x682E6FF3D6B2B8A3,
    0x748F82EE5DEFB2FC,
    0x78A5636F43172F60,
    0x84C87814A1F0AB72,
    0x8CC702081A6439EC,
    0x90BEFFFA23631E28,
    0xA4506CEBDE82BDE9,
    0xBEF9A3F7B2C67915,
    0xC67178F2E372532B,
    0xCA273ECEEA26619C,
    0xD186B8C721C0C207,
    0xEADA7DD6CDE0EB1E,
    0xF57D4F7FEE6ED178,
    0x06F067AA72176FBA,
    0x0A637DC5A2C898A6,
    0x113F9804BEF90DAE,
    0x1B710B35131C471B,
    0x28DB77F523047D84,
    0x32CAAB7B40C72493,
    0x3C9EBE0A15C9BEBC,
    0x431D67C49C100D4C,
    0x4CC5D4BECB3E42B6,
    0x597F299CFC657E2A,
    0x5FCB6FAB3AD6FAEC,
    0x6C44198C4A475817,
]


def rotr(value, shift):
    return ((value >> shift) | (value << (64 - shift))) & MASK64


def sha512_padding(message_len):
    zeros = (112 - (message_len + 1) % 128) % 128
    return b"\x80" + (b"\x00" * zeros) + (message_len * 8).to_bytes(16, "big")


def sha512_compress(state, block):
    words = [0] * 80
    for index in range(16):
        start = index * 8
        words[index] = int.from_bytes(block[start:start + 8], "big")

    for index in range(16, 80):
        sigma0 = rotr(words[index - 15], 1) ^ rotr(words[index - 15], 8) ^ (words[index - 15] >> 7)
        sigma1 = rotr(words[index - 2], 19) ^ rotr(words[index - 2], 61) ^ (words[index - 2] >> 6)
        words[index] = (words[index - 16] + sigma0 + words[index - 7] + sigma1) & MASK64

    a, b, c, d, e, f, g, h = state
    for index in range(80):
        sum1 = rotr(e, 14) ^ rotr(e, 18) ^ rotr(e, 41)
        choice = (e & f) ^ (~e & g)
        temp1 = (h + sum1 + choice + K[index] + words[index]) & MASK64
        sum0 = rotr(a, 28) ^ rotr(a, 34) ^ rotr(a, 39)
        majority = (a & b) ^ (a & c) ^ (b & c)
        temp2 = (sum0 + majority) & MASK64

        h = g
        g = f
        f = e
        e = (d + temp1) & MASK64
        d = c
        c = b
        b = a
        a = (temp1 + temp2) & MASK64

    return [
        (state[0] + a) & MASK64,
        (state[1] + b) & MASK64,
        (state[2] + c) & MASK64,
        (state[3] + d) & MASK64,
        (state[4] + e) & MASK64,
        (state[5] + f) & MASK64,
        (state[6] + g) & MASK64,
        (state[7] + h) & MASK64,
    ]


def sha512_length_extend(digest_hex, original_total_len, extra):
    state = [int(digest_hex[offset:offset + 16], 16) for offset in range(0, 128, 16)]
    prior_len = original_total_len + len(sha512_padding(original_total_len))
    data = extra + sha512_padding(prior_len + len(extra))

    for offset in range(0, len(data), 128):
        state = sha512_compress(state, data[offset:offset + 128])

    return "".join(f"{value:016x}" for value in state)


def parse_like_service(decoded_text):
    sanitized = "".join(char if char in ALLOWED_CHARS else "" for char in decoded_text)
    try:
        parsed = ast.literal_eval(sanitized)
    except (SyntaxError, ValueError, TypeError):
        return None
    return parsed if isinstance(parsed, list) else None


def escape_for_unicode_escape(data):
    return "".join(f"\\x{byte:02x}" for byte in data).encode()


def build_query(vector, extra_tail):
    message = str(vector)[1:-1].encode("latin-1")
    glue_padding = sha512_padding(SALT_SIZE + len(message))
    extra = b""
    if extra_tail:
        extra = ("," + ",".join(map(str, extra_tail))).encode()

    raw = b"[" + message + glue_padding + extra + b"]"
    parsed = parse_like_service(raw.decode("latin-1"))
    return raw, parsed, extra


def recv_instance(io):
    iv = None
    ciphertext = None
    trusted_pairs = []

    while len(trusted_pairs) < 5:
        line = io.recvline().decode().rstrip("\n")
        if line.startswith("IV: "):
            iv = bytes.fromhex(line.split(": ", 1)[1])
        elif line.startswith("Ciphertext: "):
            ciphertext = bytes.fromhex(line.split(": ", 1)[1])
        elif line.startswith("(["):
            trusted_pairs.append(ast.literal_eval(line))

    return iv, ciphertext, trusted_pairs


def query(io, vector_text, digest_hex):
    io.sendlineafter(b"Enter your vector: ", vector_text)
    io.sendlineafter(b"Enter its salted hash: ", digest_hex.encode())
    line = io.recvline().decode().strip()
    if not line.startswith("The computed dot product is: "):
        raise ValueError(f"unexpected server response: {line}")
    return int(line.split(": ", 1)[1])


def recover_key(io, trusted_pairs):
    bases = []

    for vector, digest in trusted_pairs:
        raw, parsed, extra = build_query(vector, [])
        if parsed is None or not parsed:
            raise ValueError("forged baseline did not parse as a usable vector")

        digest_ext = sha512_length_extend(digest, SALT_SIZE + len(str(vector)[1:-1]), extra)
        result = query(io, escape_for_unicode_escape(raw), digest_ext)
        bases.append(
            {
                "vector": vector,
                "digest": digest,
                "prefix": parsed,
                "baseline": result,
            }
        )

    pivot_base = min(bases, key=lambda item: len(item["prefix"]))
    pivot_len = len(pivot_base["prefix"])
    if pivot_len > 5:
        raise ValueError(f"instance not solvable enough, smallest prefix is {pivot_len}")

    key = [None] * KEY_SIZE
    baseline = pivot_base["baseline"]

    for index in range(pivot_len, KEY_SIZE):
        tail = [0] * (index - pivot_len) + [1]
        raw, _, extra = build_query(pivot_base["vector"], tail)
        digest_ext = sha512_length_extend(
            pivot_base["digest"],
            SALT_SIZE + len(str(pivot_base["vector"])[1:-1]),
            extra,
        )
        result = query(io, escape_for_unicode_escape(raw), digest_ext)
        key[index] = result - baseline

    rows = []
    rhs = []
    for base in bases:
        rows.append(base["prefix"][:pivot_len])
        known_suffix = sum(
            base["prefix"][index] * key[index]
            for index in range(pivot_len, min(len(base["prefix"]), KEY_SIZE))
        )
        rhs.append(base["baseline"] - known_suffix)

    solved_prefix = None
    for picked in combinations(range(len(rows)), pivot_len):
        matrix = Matrix([rows[index] for index in picked])
        if matrix.det() == 0:
            continue

        solution = matrix.LUsolve(Matrix([rhs[index] for index in picked]))
        candidate = []
        valid = True
        for value in solution:
            if not value.is_integer:
                valid = False
                break
            as_int = int(value)
            if not 0 <= as_int <= 255:
                valid = False
                break
            candidate.append(as_int)

        if valid:
            solved_prefix = candidate
            break

    if solved_prefix is None:
        raise ValueError("could not solve the leading key bytes for this instance")

    key[:pivot_len] = solved_prefix
    return bytes(key)


def decrypt_flag(key, iv, ciphertext):
    cipher = AES.new(key, AES.MODE_CBC, iv)
    return unpad(cipher.decrypt(ciphertext), AES.block_size).decode()


def solve_once(io):
    iv, ciphertext, trusted_pairs = recv_instance(io)
    key = recover_key(io, trusted_pairs)
    return decrypt_flag(key, iv, ciphertext)


def connect(mode):
    if mode == "remote":
        return remote(HOST, PORT)
    if mode == "local":
        return process(["python", "remote.py"])
    raise ValueError("mode must be 'local' or 'remote'")


def main():
    context.log_level = "error"
    attempts = 0

    while True:
        attempts += 1
        io = connect("remote")
        try:
            flag = solve_once(io)
            print(flag)
            print(f"attempts={attempts}")
            return
        except Exception:
            io.close()


if __name__ == "__main__":
    main()