다시 이음

딥러닝(5) - callback 본문

AI 일별 공부 정리

딥러닝(5) - callback

Taeho(Damon) 2021. 10. 23. 17:48

안녕하세요.

 

오늘은 딥러닝 모델을 학습하면서 도움이 되는 callback에 대해서 알아보겠습니다.

 

Callback

 

Tensorflow가 기본적인 Log를 출력해주기는 하지만 훈련이 끝날 때까지 출력되는 Log만 보기에는 효율이 좋지 않습니다.

그래서 Tensorflow는 모델을 훈련시키는 동안 어떤 이벤트들이 발생하면 개발자가 원하는 동작(Callback 함수)을 수행할 수 있는 방법을 제공하고 있습니다. 

그것이 바로 Callback 입니다.

 

-tensorflow에서 제공하는 코드를 살펴보겠습니다.

import tensorflow as tf
from tensorflow import keras


#모든 콜백은 keras.callbacks.Callback을 하위 클래스화 해야 함으로 클래스 지정을 해줍니다.
class CustomCallback(keras.callbacks.Callback):
#다양한 단계에서 호출되는 메소드를 재정의 해줍니다.
    def on_train_begin(self, logs=None): #훈련이 시작될 때 호출
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None): #훈련이 끝날때 호출
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None): #epoch가 시작될때 호출
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None): #epoch가 끝날 때 호출
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None): #훈련중 배치를 처리하기 전에 호출
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):#훈련이 끝날 때 호출 여기서 log는 매트릭 결과를 포함합니다.
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

 

- callback 함수 적용하기

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)

res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])

 

- callback 다양하게 활용해보기

#콜백이 실행하는 곳에 loss값을 표기해줍니다.
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )
# epoch가 끝날 때마다 accuracy의 값이 0.9보다 높으면 훈련은 stop하는 콜백입니다.
class myCallback(tf.keras.callbacks.Callback): 
    def on_epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy') > .90):   
            self.model.stop_training = True

 

-그 외에 keras가 제공하는 callback

#ealry stopping 기능
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

#callback지정을 위해 새로운 클래스를 모두 만들기 번거로울때 함수형태로 변환시켜줍니다.(LambdaCallback기능)
callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: print('We are starting epoch {}'.format(epoch +1 )))
model.fit(dataset, epochs=10, callbacks=[callback]) 

#LearningRateScheduler 기능
def scheduler(epoch):
    if epoch < 10:
        return 0.001
    else:
       return 0.001 * tf.math.exp(0.1 * (10 - epoch))
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

model.fit(dataset, epochs=100, callbacks=[callback])

❄️ 참고 : https://www.tensorflow.org/guide/keras/custom_callback?hl=ko