Generalized MT19937 PRG reverse analysis
Problem Description
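Given the 10 parameters N, M, A, U, S, B, T, C, L, F of a generalized MT19937 generator, followed by the first N random numbers it produced, recover the seed.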
Approach Analysis
20pts
Brute force enumeration of the seed is sufficient. Code omitted.
100pts
The general approach is: generated random numbers → state after the twist → possible values of state[-1] before the twist → the seed.
We first consider the standard case, where the parameters are exactly the same as in the paper. There is an existing solution available.
The specific principles will not be elaborated here, as the provided link and Mivik's solution explain it clearly.
Then, in the standard case the highest bit of A is 1 while the highest bit of y >> 1 is always 0, so whether y was odd can be determined from the highest bit of tmp.
The code for reversing the twist in the original solution is as follows:
def backtrace(cur):
    """Undo one standard MT19937 twist on a 624-word state.

    Uses the standard constants (624 words, offset 397, matrix constant
    0x9908b0df).  The list ``cur`` is modified in place and also returned.

    Each word is rebuilt from two pre-twist ``y`` values: the highest bit
    comes from the ``y`` of index ``i`` and the low 31 bits from the ``y`` of
    index ``i - 1``.  For ``i == 0`` the index ``i - 1`` wraps around to the
    last element of the list.
    """
    high = 0x80000000
    low = 0x7fffffff
    mask = 0x9908b0df

    def recover_y(tmp):
        # The highest bit of the constant is set while ``y >> 1`` never has
        # it set, so a set highest bit means the constant was XORed in,
        # i.e. ``y`` was odd.
        if tmp & high == high:
            return ((tmp ^ mask) << 1) | 1
        return tmp << 1

    state = cur
    # Walk backwards so every word read has the value it had when the
    # forward twist read it.
    for i in range(623, -1, -1):
        # recover highest bit
        res = recover_y(state[i] ^ state[(i + 397) % 624]) & high
        # recover other 31 bits
        res |= recover_y(state[i - 1] ^ state[(i + 396) % 624]) & low
        state[i] = res
    return state
The issue here lies in the condition `if tmp & high == high:`, which is no longer valid once the highest bit of A is not guaranteed to be 1, and therefore cannot be used to accurately determine the value of tmp.
A straightforward approach is to enumerate the four possible values of mt[N - 1]. Since mt[i - 1] ^ mt[i - 1] >> 30 is reversible and F is invertible modulo 2 ** 32, every step of the seeding recurrence can be undone. Therefore, we can backtrack the seed by finding the inverse of F modulo 2 ** 32.
To verify these four seeds, we can use each seed to regenerate a few random numbers and compare them with the input.
Complete Code
from gmpy2 import invert
def _int32(x):
return int(0xFFFFFFFF & x)
class mt19937:
    """Generalized MT19937 generator plus the helpers used to reverse it.

    The generator is parameterised by the module-level names ``N``, ``M``,
    ``A``, ``U``, ``S``, ``B``, ``T``, ``C``, ``L`` and ``F``; they must be
    defined before an instance is created or any method is called.

    Attributes:
        mt: list of ``N`` state words.
        mti: index of the next state word to temper and return.
    """

    def __init__(self, seed=0):
        """Fill the state from ``seed`` with the linear seeding recurrence.

        ``seed`` is stored in ``mt[0]`` exactly as given; every later word is
        reduced to 32 bits.
        """
        self.mt = [0] * N
        self.mt[0] = seed
        self.mti = 0
        for i in range(1, N):
            self.mt[i] = _int32(F * (self.mt[i - 1] ^ self.mt[i - 1] >> 30) + i)

    def getstate(self, op=False):
        """Return the next 32-bit output and advance ``mti``.

        When ``mti`` is 0 the state is twisted first, unless ``op`` is true,
        in which case the current state is used as-is.
        """
        if self.mti == 0 and not op:
            self.twist()
        y = self.extract_number(self.mt[self.mti])
        self.mti = (self.mti + 1) % N
        return _int32(y)

    def twist(self):
        """Regenerate all ``N`` state words in place."""
        for i in range(N):
            # Highest bit of word i joined with the low 31 bits of word i + 1.
            y = _int32((self.mt[i] & 0x80000000) + (self.mt[(i + 1) % N] & 0x7fffffff))
            self.mt[i] = (y >> 1) ^ self.mt[(i + M) % N]
            if y % 2 != 0:
                self.mt[i] ^= A

    def inverse_right(self, res, shift, mask=0xffffffff, bits=32):
        """Invert ``res = x ^ ((x >> shift) & mask)`` for a ``bits``-wide ``x``.

        Each pass makes ``shift`` more high bits correct, so ``bits // shift``
        passes are enough.  Raises ``ZeroDivisionError`` when ``shift`` is 0.
        """
        tmp = res
        for _ in range(bits // shift):
            tmp = res ^ tmp >> shift & mask
        return tmp

    def inverse_left(self, res, shift, mask=0xffffffff, bits=32):
        """Invert ``res = x ^ ((x << shift) & mask)`` for a ``bits``-wide ``x``.

        Each pass makes ``shift`` more low bits correct, so ``bits // shift``
        passes are enough.  Raises ``ZeroDivisionError`` when ``shift`` is 0.
        """
        tmp = res
        for _ in range(bits // shift):
            tmp = res ^ tmp << shift & mask
        return tmp

    def extract_number(self, y):  # namely "temper" in Mivik's code
        """Temper state word ``y`` into a 32-bit output value."""
        y = y ^ y >> U
        y = y ^ y << S & B
        y = y ^ y << T & C
        y = y ^ y >> L
        return y & 0xffffffff

    def recover(self, y):  # inverse of extract_number
        """Undo ``extract_number`` by inverting its four steps in reverse order."""
        y = self.inverse_right(y, L)
        y = self.inverse_left(y, T, C)
        y = self.inverse_left(y, S, B)
        y = self.inverse_right(y, U)
        return y & 0xffffffff

    def setstate(self, s):  # N generated random numbers -> mt[] after twisting
        """Load the state that produced the ``N`` consecutive outputs ``s``.

        Raises:
            ValueError: if ``s`` does not contain exactly ``N`` values.
        """
        if len(s) != N:
            raise ValueError("The length of prediction must be N!")
        for i in range(N):
            self.mt[i] = self.recover(s[i])
        self.mti = 0

    # Disabled helper, kept for reference: predicts further outputs after
    # being given N of them (not needed for recovering the seed).
    #
    # def predict(self, s):
    #     self.setstate(s)
    #     self.twist()
    #     return self.getstate(True)

    def _unshift(self, tmp, odd):
        """Rebuild a pre-twist ``y`` from ``tmp``, given whether ``y`` was odd."""
        if odd:
            return ((tmp ^ A) << 1) | 1
        return tmp << 1

    def invtwist(self):  # mt[] after twisting -> 4 possible values of mt[-1] before twisting
        """Return four candidates for the pre-twist value of ``mt[N - 1]``.

        The parity of the two ``y`` values involved is not derived from
        ``tmp`` here, so both parities are tried for each.  Candidate
        ``s * 2 + t`` treats the ``y`` that supplies the highest bit as odd
        when ``s == 0`` and the ``y`` that supplies the low 31 bits as odd
        when ``t == 0``.
        """
        high = 0x80000000
        low = 0x7fffffff
        i = N - 1  # only the last word is needed
        upper = self.mt[i] ^ self.mt[(i + M) % N]
        lower = self.mt[i - 1] ^ self.mt[(i + M - 1) % N]
        opt = [0] * 4
        for s in range(2):
            for t in range(2):
                res = self._unshift(upper, s == 0) & high
                res |= self._unshift(lower, t == 0) & low
                opt[s * 2 + t] = res
        return opt

    def recover_seed(self, last):  # mt[-1] -> mt[0]
        """Walk the seeding recurrence backwards from ``mt[N - 1]`` to ``mt[0]``.

        Raises:
            ValueError: if ``F`` has no inverse modulo ``2 ** 32``.
        """
        n = 1 << 32
        inv = pow(F, -1, n)  # inverse of F mod 2 ^ 32
        for i in range(N - 1, 0, -1):
            last = ((last - i) * inv) % n
            last = self.inverse_right(last, 30)
        return last
# Read the ten generator parameters from the first line, then N outputs,
# one per line.
N, M, A, U, S, B, T, C, L, F = map(int, input().split())
inpt = [int(input()) for _ in range(N)]

D = mt19937()
D.setstate(inpt)  # using the input to recover state after twisting

# Comparing the first few outputs is enough to tell the candidates apart;
# never look past the N values that were actually read.
checked = min(10, N)
for candidate in D.invtwist():  # four possibilities of D.mt[-1]
    seed = D.recover_seed(candidate)
    E = mt19937(seed)  # another pseudo random number generator
    E.twist()  # bring E to the state that produced the first N outputs
    if all(E.extract_number(E.mt[j]) == inpt[j] for j in range(checked)):
        print(seed)
        break
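Time complexity: O(N), since recovering the state, backtracking each of the four candidate seeds and regenerating the state to verify it are all linear in N.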