2022-10-12 20:49:47 +03:00
import tqdm
2022-10-11 22:03:05 +03:00
2022-10-12 20:49:47 +03:00
class LearnScheduleIterator:
    """Iterator over ``(learn_rate, end_step)`` pairs parsed from a schedule string.

    Each item means "train at ``learn_rate`` until ``end_step``". End steps are
    clamped to ``max_steps``. Parsing stops at the first entry that reaches
    past ``max_steps``, or that uses the "until max_steps" forms (``-1`` or a
    bare rate).
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: schedule string of comma-separated ``rate:step``
                entries. Empty entries (e.g. from a trailing comma) are
                ignored. A step of ``-1``, or a bare ``rate`` with no step,
                means "until ``max_steps``" and ends the schedule. A plain
                int/float is also accepted and treated as a constant rate.
            max_steps: last training step; every end step is clamped to it.
            cur_step: step training (re)starts from; entries whose step is
                not greater than it are skipped (the ``-1`` sentinel excepted).

        Raises:
            Exception: if an entry cannot be parsed, contains more than one
                ``:``, or no entry remains after ``cur_step``.
        """
        if isinstance(learn_rate, (int, float)):
            # Allow a bare numeric rate in addition to the string form.
            learn_rate = str(learn_rate)

        pairs = learn_rate.split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0

        try:
            for pair in pairs:
                if not pair.strip():
                    continue
                tmp = pair.split(':')
                if len(tmp) == 1:
                    # Bare rate: use it for the rest of training.
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    break
                if len(tmp) != 2:
                    # e.g. "0.1:100:200" used to be silently read as a bare rate.
                    raise ValueError(f'malformed schedule entry {pair!r}')
                step = int(tmp[1])
                if step == -1:
                    # -1 is the "until max_steps" sentinel and ends the schedule.
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    break
                if step > cur_step:
                    self.rates.append((float(tmp[0]), min(step, max_steps)))
                    self.maxit += 1
                    if step > max_steps:
                        break  # later entries could never be reached
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if not self.rates:
                raise ValueError('no schedule entries remain after cur_step')
        except ValueError as err:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from err

    def __iter__(self):
        """Return self; the schedule can be consumed only once."""
        return self

    def __next__(self):
        """Return the next ``(learn_rate, end_step)`` pair or raise StopIteration."""
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
2022-10-12 20:49:47 +03:00
class LearnRateScheduler:
    """Step an optimizer's learning rate through a parsed rate schedule.

    Attributes:
        schedules: the ``LearnScheduleIterator`` being consumed.
        learn_rate: rate of the current schedule entry.
        end_step: step at which the current entry ends.
        verbose: whether rate changes are printed.
        finished: set to True once ``apply`` finds no further entry.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Parse the schedule and load its first entry.

        The first rate is not written into any optimizer here. Presumably the
        caller builds its optimizer with ``self.learn_rate`` (confirm at the
        call site).

        Raises:
            Exception: if the schedule is invalid or yields no entries.
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            # Defensive: don't let a bare StopIteration escape a constructor.
            raise Exception('Invalid learning rate schedule: it yields no (rate, step) entries.') from None
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def apply(self, optimizer, step_number):
        """Advance the schedule once ``step_number`` reaches the current end step.

        Before ``end_step`` this is a no-op. Otherwise it moves to the next
        entry (at most one per call) and writes its rate into every
        ``optimizer.param_groups`` entry. When no entry is left, it sets
        ``finished`` and leaves the optimizer untouched.
        """
        if step_number < self.end_step:
            return

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            # Only exhaustion means "finished"; other errors must surface.
            self.finished = True
            return

        if self.verbose:
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate