Skip to content

Instantly share code, notes, and snippets.

@Filiprogrammer
Created November 5, 2024 12:40
Show Gist options
  • Select an option

  • Save Filiprogrammer/6767c6c89a963cf9d728bd5ae6d69e44 to your computer and use it in GitHub Desktop.

Select an option

Save Filiprogrammer/6767c6c89a963cf9d728bd5ae6d69e44 to your computer and use it in GitHub Desktop.
Implementation of AES-128 encryption and decryption of a single 16-byte block, with S-Box generation
#!/usr/bin/env python3
# AES lookup tables and constants.
#
# S_BOX is filled in place by generate_s_box() and S_BOX_INVERSE by
# generate_inverse_s_box(); until then every entry is None.
S_BOX = [None for _ in range(256)]
S_BOX_INVERSE = [None for _ in range(256)]

# MixColumns matrix as a flat list of 16 GF(2^8) coefficients; entry
# [row * 4 + col] holds the coefficient for (row, col).
MIX_COLUMNS_MATRIX = [
    0x02, 0x03, 0x01, 0x01,  # row 0
    0x01, 0x02, 0x03, 0x01,  # row 1
    0x01, 0x01, 0x02, 0x03,  # row 2
    0x03, 0x01, 0x01, 0x02,  # row 3
]

# Inverse MixColumns matrix, same flat row-major layout as above.
MIX_COLUMNS_MATRIX_INVERSE = [
    0x0E, 0x0B, 0x0D, 0x09,  # row 0
    0x09, 0x0E, 0x0B, 0x0D,  # row 1
    0x0D, 0x09, 0x0E, 0x0B,  # row 2
    0x0B, 0x0D, 0x09, 0x0E,  # row 3
]

# Key-schedule round constants; g() XORs RC[round_number] into the first
# byte of the transformed word. Ten entries cover the ten AES-128 rounds.
RC = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]
def g(wi, round_number):
    """Apply the AES key-schedule core function to one 32-bit word.

    The word is split into 4 little-endian bytes and rotated left by one byte
    (RotWord). Each byte is substituted through S_BOX (SubWord), and the
    first byte of the result is XORed with RC[round_number].

    Args:
        wi: the word as an int; byte 0 is its least significant byte.
        round_number: index into RC (0-9, as used by generate_round_keys).

    Returns:
        The transformed word as an int, packed little-endian.
    """
    word_bytes = wi.to_bytes(4, 'little')
    # RotWord: [b0, b1, b2, b3] -> [b1, b2, b3, b0]
    rotated = word_bytes[1:] + word_bytes[:1]
    substituted = [S_BOX[byte] for byte in rotated]
    substituted[0] ^= RC[round_number]
    return int.from_bytes(substituted, byteorder='little')
def generate_round_keys(key):
    """Expand a 16-byte AES-128 key into 11 round keys.

    Args:
        key: the cipher key; a bytes-like object (or sequence of ints in
            0-255) of exactly 16 bytes.

    Returns:
        A list of 11 round keys, each a list of 16 ints. round_keys[0] is the
        original key itself.

    Raises:
        ValueError: if ``key`` is not exactly 16 bytes long. Previously a short
            key was silently zero-padded (``int.from_bytes`` of a short or empty
            slice still succeeds), and a long key was silently truncated.
    """
    # Only AES-128 (4-word key, 10 rounds) is implemented.
    if len(key) != 16:
        raise ValueError(
            "AES-128 key must be exactly 16 bytes, got {}".format(len(key)))

    # w holds the 44 expanded 32-bit words. Each is packed little-endian, so
    # to_bytes(4, 'little') gives the bytes back in their original order.
    w = [int.from_bytes(key[4 * j:4 * j + 4], byteorder='little')
         for j in range(4)]
    for i in range(10):
        # The first word of each round mixes in g() of the previous round's
        # last word; the other three chain off the word just produced.
        w.append(w[4 * i] ^ g(w[4 * i + 3], i))
        for j in range(1, 4):
            w.append(w[4 * i + j] ^ w[-1])

    # A comprehension, so every round key is its own list (no [[]] * 11 aliasing).
    return [
        [byte for j in range(4) for byte in w[4 * i + j].to_bytes(4, 'little')]
        for i in range(11)
    ]
def substitute_bytes(input, sbox):
    """Return a new 16-element list with each byte replaced by its sbox entry.

    Only the first 16 entries of ``input`` are read. A shorter input raises
    IndexError.
    """
    return [sbox[input[position]] for position in range(16)]
def shift_rows(input):
    """Return the AES ShiftRows permutation of a column-major 16-byte state.

    Index 4*c + r holds row r of column c. Row r is rotated left by r
    positions, so output column c takes its row-r byte from input column
    (c + r) mod 4.
    """
    shifted = []
    for column in range(4):
        for row in range(4):
            shifted.append(input[((column + row) % 4) * 4 + row])
    return shifted
def shift_rows_inverse(input):
    """Return the inverse ShiftRows permutation of a column-major 16-byte state.

    Row r is rotated right by r positions, so output column c takes its row-r
    byte from input column (c - r) mod 4.
    """
    unshifted = []
    for column in range(4):
        for row in range(4):
            unshifted.append(input[((column - row) % 4) * 4 + row])
    return unshifted
def get_most_significant_bit_number(a):
    """Return the zero-based index of the highest set bit of ``a``.

    For a polynomial over GF(2) encoded as an int, this is its degree.
    Returns -1 for 0.

    Raises:
        ValueError: if ``a`` is negative. The previous shift loop never
            terminated for negatives, because -1 >> 1 == -1.
    """
    if a < 0:
        raise ValueError("a must be non-negative, got {}".format(a))
    return a.bit_length() - 1
def gf_mod_reduce(a, m):
    """Reduce the GF(2) polynomial ``a`` modulo the polynomial ``m``.

    Both are bit-encoded: bit i is the coefficient of x^i. The divisor is
    repeatedly shifted to line up with ``a``'s leading term and XORed in,
    until ``a`` has lower degree than ``m``.
    """
    divisor_degree = get_most_significant_bit_number(m)
    threshold = 1 << divisor_degree
    while a >= threshold:
        # Align m's leading term under a's leading term and cancel it.
        shift = get_most_significant_bit_number(a) - divisor_degree
        a ^= m << shift
    return a
def rijndael_mul(a, b):
    """Multiply two elements of GF(2^8) using the AES polynomial 0x11B.

    Forms the carry-less product of ``b`` with the low 8 bits of ``a``. It then
    reduces that product modulo x^8 + x^4 + x^3 + x + 1.
    """
    product = 0
    for bit in range(8):
        if a >> bit & 1:
            product ^= b << bit
    return gf_mod_reduce(product, 0x11B)
def multiplicative_inverse(a):
    """Return the inverse of ``a`` in GF(2^8) under the AES polynomial 0x11B.

    0 has no inverse and, as in the AES S-box construction, maps to 0.

    Raises:
        ValueError: if no byte value b satisfies rijndael_mul(a, b) == 1.
            One example is a nonzero ``a`` whose low 8 bits are all zero.
            This was a bare ``Exception`` before; ValueError subclasses
            Exception, so existing ``except Exception`` handlers still work.
    """
    if a == 0:
        return 0
    # Brute-force search. Candidate 0 is skipped: rijndael_mul(a, 0) is
    # always 0, never 1.
    for candidate in range(1, 256):
        if rijndael_mul(a, candidate) == 1:
            return candidate
    raise ValueError("Found no multiplicative inverse")
def generate_single_s_box_value(a):
    """Compute the AES S-box entry for byte ``a``.

    The entry is the affine transform of the GF(2^8) multiplicative inverse
    of ``a``, XORed with the constant 0x63. The 8x8 affine matrix is
    circulant: output bit i is b_i ^ b_(i+4) ^ b_(i+5) ^ b_(i+6) ^ b_(i+7),
    indices mod 8. Multiplying by it therefore equals XORing the byte with
    its 8-bit left rotations by 1, 2, 3 and 4.
    """
    inverse = multiplicative_inverse(a)
    transformed = inverse
    for shift in range(1, 5):
        # 8-bit circular left rotation of the inverse by `shift` bits.
        transformed ^= ((inverse << shift) | (inverse >> (8 - shift))) & 0xFF
    return transformed ^ 0b01100011
def generate_s_box():
    """Fill the global S_BOX list in place with all 256 S-box entries."""
    # Slice assignment keeps the same list object, so existing references
    # to S_BOX see the generated values.
    S_BOX[:] = [generate_single_s_box_value(value) for value in range(256)]
def generate_inverse_s_box():
    """Fill the global S_BOX_INVERSE list in place by inverting S_BOX.

    S_BOX must already be generated: each of its entries is used as an
    index into S_BOX_INVERSE.
    """
    for plain, substituted in enumerate(S_BOX):
        S_BOX_INVERSE[substituted] = plain
def matrix_mul_vector(m, v):
    """Multiply a 4x4 GF(2^8) matrix by a 4-element vector.

    ``m`` is a flat list of 16 coefficients with entry (row, col) at index
    row * 4 + col. Addition in GF(2^8) is XOR. Returns a new 4-element list.
    """
    result = []
    for row in range(4):
        accumulator = 0
        for col in range(4):
            accumulator ^= rijndael_mul(m[row * 4 + col], v[col])
        result.append(accumulator)
    return result
def mix_columns(state, mix_columns_matrix):
    """Multiply each of the four state columns by ``mix_columns_matrix``.

    Args:
        state: 16 bytes in which bytes 4*c .. 4*c+3 form column c.
        mix_columns_matrix: flat 4x4 GF(2^8) matrix (entry (row, col) at
            row * 4 + col), e.g. MIX_COLUMNS_MATRIX or
            MIX_COLUMNS_MATRIX_INVERSE.

    Returns:
        A new 16-element list. ``state`` is no longer modified in place, so
        callers' lists are not clobbered and immutable inputs (bytes, tuples)
        work too.
    """
    mixed = []
    for column in range(4):
        mixed.extend(
            matrix_mul_vector(mix_columns_matrix, state[column * 4:column * 4 + 4]))
    return mixed
def add_round_key(input, round_key):
    """Return a new 16-element list: the first 16 bytes of ``input`` XOR ``round_key``.

    Either argument being shorter than 16 raises IndexError.
    """
    return [input[index] ^ round_key[index] for index in range(16)]
def aes_round(input, round_key, is_last_round):
    """Run one AES encryption round on a 16-byte state and return the new state.

    SubBytes and ShiftRows always run. MixColumns is skipped when
    ``is_last_round`` is truthy. AddRoundKey with ``round_key`` always runs last.
    """
    shifted = shift_rows(substitute_bytes(input, S_BOX))
    mixed = shifted if is_last_round else mix_columns(shifted, MIX_COLUMNS_MATRIX)
    return add_round_key(mixed, round_key)
def aes_inverse_round(input, round_key, is_last_round):
    """Run one AES decryption round on a 16-byte state and return the new state.

    Applies inverse ShiftRows, inverse SubBytes and AddRoundKey with
    ``round_key``. It then applies inverse MixColumns unless ``is_last_round``
    is truthy.
    """
    state = add_round_key(
        substitute_bytes(shift_rows_inverse(input), S_BOX_INVERSE), round_key)
    if is_last_round:
        return state
    return mix_columns(state, MIX_COLUMNS_MATRIX_INVERSE)
def encrypt(plaintext, key):
    """Encrypt a single 16-character block with AES-128.

    Each character is treated as one byte (its code point must be 0-255).

    Args:
        plaintext: exactly 16 characters.
        key: exactly 16 characters.

    Returns:
        The ciphertext as a str of 16 characters with code points 0-255.

    Raises:
        ValueError: if ``plaintext`` or ``key`` is not exactly 16 characters,
            or contains a character above 0xFF. Previously a short plaintext
            raised IndexError, a long one was silently truncated, and a short
            key was silently zero-padded.

    If S_BOX has not been generated yet, it is generated on first use. This
    lets the function work when the file is imported as a module.
    """
    plaintext_bytes = bytearray(map(ord, plaintext))
    key_bytes = bytearray(map(ord, key))
    if len(plaintext_bytes) != 16:
        raise ValueError(
            "plaintext must be exactly 16 characters (one AES block), got {}"
            .format(len(plaintext_bytes)))
    if len(key_bytes) != 16:
        raise ValueError(
            "key must be exactly 16 characters (AES-128), got {}"
            .format(len(key_bytes)))

    # S_BOX starts out all-None and is only filled by generate_s_box();
    # both the key schedule and the rounds read it.
    if None in S_BOX:
        generate_s_box()

    round_keys = generate_round_keys(key_bytes)
    # Initial transformation: AddRoundKey with the original key.
    state = add_round_key(plaintext_bytes, round_keys[0])
    # Rounds 1-9 include MixColumns; round 10 omits it.
    for i in range(1, 10):
        state = aes_round(state, round_keys[i], False)
    state = aes_round(state, round_keys[10], True)
    return ''.join(map(chr, state))
def decrypt(ciphertext, key):
    """Decrypt a single 16-character AES-128 block produced by encrypt().

    Each character is treated as one byte (its code point must be 0-255).

    Args:
        ciphertext: exactly 16 characters.
        key: exactly 16 characters.

    Returns:
        The plaintext as a str of 16 characters with code points 0-255.

    Raises:
        ValueError: if ``ciphertext`` or ``key`` is not exactly 16 characters,
            or contains a character above 0xFF. Previously bad lengths either
            raised IndexError or were silently truncated/zero-padded.

    If S_BOX or S_BOX_INVERSE has not been generated yet, it is generated on
    first use. This lets the function work when the file is imported as a
    module.
    """
    ciphertext_bytes = bytearray(map(ord, ciphertext))
    key_bytes = bytearray(map(ord, key))
    if len(ciphertext_bytes) != 16:
        raise ValueError(
            "ciphertext must be exactly 16 characters (one AES block), got {}"
            .format(len(ciphertext_bytes)))
    if len(key_bytes) != 16:
        raise ValueError(
            "key must be exactly 16 characters (AES-128), got {}"
            .format(len(key_bytes)))

    # The key schedule reads S_BOX and the inverse rounds read S_BOX_INVERSE;
    # both start out all-None until generated. The inverse is derived from S_BOX.
    if None in S_BOX:
        generate_s_box()
    if None in S_BOX_INVERSE:
        generate_inverse_s_box()

    round_keys = generate_round_keys(key_bytes)
    # Undo the final AddRoundKey first.
    state = add_round_key(ciphertext_bytes, round_keys[10])
    # Inverse rounds with keys 9..1 include inverse MixColumns; the last one
    # (round key 0) does not.
    for i in range(9, 0, -1):
        state = aes_inverse_round(state, round_keys[i], False)
    state = aes_inverse_round(state, round_keys[0], True)
    return ''.join(map(chr, state))
if __name__ == "__main__":
    def _hex_dump(text):
        """Return each character's code point as colon-separated 2-digit hex."""
        return ":".join("{:02x}".format(ord(c)) for c in text)

    def _show(label, text):
        """Print ``label``, the repr of ``text`` and its hex dump on one line."""
        print(label + repr(text) + "\t(" + _hex_dump(text) + ")")

    print("Generating S-Box...")
    generate_s_box()
    print("Done")
    print("Generating inverse S-Box...")
    generate_inverse_s_box()
    print("Done")

    PLAINTEXT = "Hello plaintext!"
    KEY = "Super secret key"
    _show("PLAINTEXT:\t", PLAINTEXT)
    _show("KEY:\t\t\t", KEY)
    ciphertext = encrypt(PLAINTEXT, KEY)
    _show("ENCRYPTED:\t", ciphertext)
    decrypted = decrypt(ciphertext, KEY)
    _show("DECRYPTED:\t", decrypted)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment