diff --git a/lib/crypt/milenage.c b/lib/crypt/milenage.c index 229dca66d..df8bba2ad 100644 --- a/lib/crypt/milenage.c +++ b/lib/crypt/milenage.c @@ -22,10 +22,12 @@ #define os_memcmp memcmp #define os_memcmp_const memcmp -int aes_128_encrypt_block(const uint8_t *key, - const uint8_t *in, uint8_t *out); +static void ShiftBits(uint8_t r, uint8_t rijndaelInput[16], + uint8_t temp[16], const uint8_t opc[16]); +static uint8_t *bits_shift(uint32_t bit_valid, uint8_t *dst, + uint8_t *src, uint32_t numBits); -int aes_128_encrypt_block(const uint8_t *key, +static int aes_128_encrypt_block(const uint8_t *key, const uint8_t *in, uint8_t *out) { const int key_bits = 128; @@ -55,8 +57,10 @@ int milenage_f1(const uint8_t *opc, const uint8_t *k, { uint8_t tmp1[16], tmp2[16], tmp3[16]; int i; +#if 1 /* R1-R5 issues1153 */ + uint8_t r1 = 64; +#endif - /* tmp1 = TEMP = E_K(RAND XOR OP_C) */ for (i = 0; i < 16; i++) tmp1[i] = _rand[i] ^ opc[i]; if (aes_128_encrypt_block(k, tmp1, tmp1)) @@ -70,8 +74,12 @@ int milenage_f1(const uint8_t *opc, const uint8_t *k, /* OUT1 = E_K(TEMP XOR rot(IN1 XOR OP_C, r1) XOR c1) XOR OP_C */ /* rotate (tmp2 XOR OP_C) by r1 (= 0x40 = 8 bytes) */ +#if 0 /* R1-R5 issues1153 */ for (i = 0; i < 16; i++) tmp3[(i + 8) % 16] = tmp2[i] ^ opc[i]; +#else + ShiftBits(r1, tmp3, tmp2, opc); +#endif /* XOR with TEMP = E_K(RAND XOR OP_C) */ for (i = 0; i < 16; i++) tmp3[i] ^= tmp1[i]; @@ -109,6 +117,13 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k, uint8_t tmp1[16], tmp2[16], tmp3[16]; int i; +#if 1 /* R1-R5 issues1153 */ + uint8_t r2 = 0; + uint8_t r3 = 32; + uint8_t r4 = 64; + uint8_t r5 = 96; +#endif + /* tmp2 = TEMP = E_K(RAND XOR OP_C) */ for (i = 0; i < 16; i++) tmp1[i] = _rand[i] ^ opc[i]; @@ -122,8 +137,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k, /* f2 and f5 */ /* rotate by r2 (= 0, i.e., NOP) */ +#if 0 /* R1-R5 issues1153 */ for (i = 0; i < 16; i++) tmp1[i] = tmp2[i] ^ opc[i]; +#else + ShiftBits(r2, tmp1, tmp2, opc); +#endif tmp1[15] ^= 1; /* XOR c2 (= ..01) */ /* f5 || f2 = E_K(tmp1) XOR OP_c */ if (aes_128_encrypt_block(k, tmp1, tmp3)) @@ -138,8 +157,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k, /* f3 */ if (ck) { /* rotate by r3 = 0x20 = 4 bytes */ +#if 0 /* R1-R5 issues1153 */ for (i = 0; i < 16; i++) tmp1[(i + 12) % 16] = tmp2[i] ^ opc[i]; +#else + ShiftBits(r3, tmp1, tmp2, opc); +#endif tmp1[15] ^= 2; /* XOR c3 (= ..02) */ if (aes_128_encrypt_block(k, tmp1, ck)) return -1; @@ -150,8 +173,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k, /* f4 */ if (ik) { /* rotate by r4 = 0x40 = 8 bytes */ +#if 0 /* R1-R5 issues1153 */ for (i = 0; i < 16; i++) tmp1[(i + 8) % 16] = tmp2[i] ^ opc[i]; +#else + ShiftBits(r4, tmp1, tmp2, opc); +#endif tmp1[15] ^= 4; /* XOR c4 (= ..04) */ if (aes_128_encrypt_block(k, tmp1, ik)) return -1; @@ -162,8 +189,12 @@ int milenage_f2345(const uint8_t *opc, const uint8_t *k, /* f5* */ if (akstar) { /* rotate by r5 = 0x60 = 12 bytes */ +#if 0 /* R1-R5 issues1153 */ for (i = 0; i < 16; i++) tmp1[(i + 4) % 16] = tmp2[i] ^ opc[i]; +#else + ShiftBits(r5, tmp1, tmp2, opc); +#endif tmp1[15] ^= 8; /* XOR c5 (= ..08) */ if (aes_128_encrypt_block(k, tmp1, tmp1)) return -1; @@ -364,3 +395,67 @@ void milenage_opc(const uint8_t *k, const uint8_t *op, uint8_t *opc) opc[i] ^= op[i]; } } + +static void ShiftBits(uint8_t r, uint8_t rijndaelInput[16], + uint8_t temp[16], const uint8_t opc[16]) +{ + uint32_t deltlen = 16 - (r / 8); + uint32_t leftout = r % 8; + uint32_t i; + + if (leftout == 0) { + for (i = 0; i < 16; i++) { + rijndaelInput[(i+deltlen) % 16] = temp[i] ^ opc[i]; + } + } else { + uint8_t temp1[16]; + uint32_t move_bits; + uint8_t temp2; + + for (i = 0; i < 16; i++) { + temp1[(i + deltlen) % 16] = temp[i] ^ opc[i]; + } + rijndaelInput[15] = 0; + move_bits = 8 - leftout; + bits_shift(move_bits, &rijndaelInput[0], temp1, (128 - leftout)); + temp2 = temp1[0] >> (8-leftout); + rijndaelInput[15] |= temp2; + } +} + +static uint8_t *bits_shift(uint32_t bit_valid, uint8_t *dst, + uint8_t *src, uint32_t numBits) +{ + uint32_t bit_used = bit_valid; + uint32_t bit_empty = 8 - bit_used; + uint32_t numBytes = numBits >> 3; + uint32_t leftBits = numBits & 0x7; + uint32_t i = 0; + uint8_t *newDst = 0; + + for (i = 0; i < numBytes; i++) { + dst[i] = (src[i] << bit_empty) | (src[i+1] >> bit_used); + } + + if (leftBits) { + if (leftBits == bit_used) { + dst[numBytes] = src[numBytes] << bit_empty; + bit_valid = 8; + newDst = &src[numBytes+1]; + } else if (leftBits < bit_used) { + dst[numBytes] = src[numBytes] << bit_empty; + bit_valid = bit_used - leftBits; + newDst = &src[numBytes]; + } else { + dst[numBytes] = src[numBytes] << bit_empty | + (src[numBytes+1] >> bit_used); + bit_valid = 8 - (leftBits - bit_used); + newDst = &src[numBytes+1]; + } + } else { + bit_valid = bit_used; + newDst = &src[numBytes]; + } + + return newDst; +}