Model Debugging and Loss Curves

This note interprets the model problems that show up in loss curves and summarizes the corresponding debugging methods.

Interpreting Loss Curves

Ideal Loss Curve Interpretation
As the number of training steps increases, the loss starts high, then decreases exponentially, and ultimately flattens out as it reaches a minimum.
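To compare a real training run against this ideal shape, one can plot the per-epoch loss from a Keras History object. The sketch below is illustrative only: the synthetic data, model architecture, and epoch count are assumptions, not part of the original note.

```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Synthetic binary-classification data, purely for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 4)).astype("float32")
y = (x[:, 0] + 0.5 * x[:, 1] > 0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

history = model.fit(x, y, epochs=30, validation_split=0.2, verbose=0)

# An ideal curve starts high, decreases roughly exponentially,
# and flattens out as it approaches a minimum.
plt.plot(history.history["loss"], label="training loss")
plt.plot(history.history["val_loss"], label="validation loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()
```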

Debugging Methods

After testing a model, the debugging process can be divided into two perspectives: data and model.

  1. Data Debugging (see the first sketch after this list)
    • Data validation with a data schema : validation of the raw data
      • Data schema : rules that encode the expected statistics of the data
    • Ensure splits are good quality
      • Verify that the train/test splits are statistically equivalent and that the split ratio stays constant
    • Test engineered data : validation of the data actually fed into the model
  2. Model Debugging (see the second sketch after this list)
    • Check that the data can predict the labels
      • Verify that the features used by the model carry predictive signal (e.g., correlation matrices)
    • Establish a baseline
      • Judge model performance by comparing against a baseline built from a simple heuristic during model development
    • Write unit tests for the ML code to detect bugs
    • Adjust your hyperparameter values (learning rate, regularization, training epochs, batch size, depth/width)
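The data debugging checks can be scripted. The sketch below, assuming pandas and SciPy, validates raw data against a hypothetical schema of expected value ranges and null rates, and uses a two-sample KS test to check that two splits are statistically equivalent; the column names, thresholds, and helper names are placeholders.

```python
import numpy as np
import pandas as pd
from scipy import stats

# Hypothetical schema: expected value ranges and allowed null rates per column.
SCHEMA = {
    "age":    {"min": 0, "max": 120, "max_null_frac": 0.0},
    "income": {"min": 0, "max": 1e7, "max_null_frac": 0.05},
}

def validate_schema(df: pd.DataFrame, schema: dict) -> list[str]:
    """Return a list of schema violations found in the raw data."""
    problems = []
    for col, rules in schema.items():
        if col not in df.columns:
            problems.append(f"missing column: {col}")
            continue
        null_frac = df[col].isna().mean()
        if null_frac > rules["max_null_frac"]:
            problems.append(f"{col}: null fraction {null_frac:.3f} too high")
        values = df[col].dropna()
        if ((values < rules["min"]) | (values > rules["max"])).any():
            problems.append(f"{col}: values outside [{rules['min']}, {rules['max']}]")
    return problems

def splits_equivalent(train: pd.Series, test: pd.Series, alpha: float = 0.01) -> bool:
    """Two-sample KS test: a tiny p-value suggests the splits differ statistically."""
    _, p_value = stats.ks_2samp(train.dropna(), test.dropna())
    return p_value >= alpha

# Toy usage on synthetic data.
rng = np.random.default_rng(0)
df = pd.DataFrame({"age": rng.integers(18, 80, 1000),
                   "income": rng.lognormal(10, 1, 1000)})
train, test = df.iloc[:800], df.iloc[800:]
print(validate_schema(df, SCHEMA))
print(splits_equivalent(train["income"], test["income"]))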
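```

For the model debugging side, the sketch below (scikit-learn on synthetic data) illustrates the first two checks: inspecting feature-label correlations for predictive signal, and comparing the model against a simple majority-class baseline. The feature names and the choice of baseline heuristic are assumptions.

```python
import numpy as np
import pandas as pd
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic data: one informative feature, one pure-noise feature.
rng = np.random.default_rng(0)
df = pd.DataFrame({"signal": rng.normal(size=2000), "noise": rng.normal(size=2000)})
df["label"] = (df["signal"] + 0.3 * rng.normal(size=2000) > 0).astype(int)

# 1) Check that the features carry predictive signal for the labels.
print(df.corr()["label"])  # 'signal' should correlate with 'label'; 'noise' should not.

# 2) Establish a baseline (majority-class heuristic) and compare the model against it.
X_train, X_test, y_train, y_test = train_test_split(
    df[["signal", "noise"]], df["label"], test_size=0.2, random_state=0)

baseline = DummyClassifier(strategy="most_frequent").fit(X_train, y_train)
model = LogisticRegression().fit(X_train, y_train)

print("baseline accuracy:", accuracy_score(y_test, baseline.predict(X_test)))
print("model accuracy:   ", accuracy_score(y_test, model.predict(X_test)))
```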

Loss Curves & Debugging Methods

Each entry below pairs a loss curve pattern and its interpretation with the actions that could fix the problem described.

  1. Model Is Not Converging (the loss oscillates) : unstable training process
    • Data Debugging
      • Check if the features can predict the labels.
      • Simplify your dataset to 10 examples that you know your model can predict on. Obtain a very low loss on the reduced dataset, then continue debugging your model on the full dataset.
    • Model Debugging
      • Reduce your learning rate.
      • Simplify your model and ensure the model outperforms your baseline. Then incrementally add complexity to the model.
  2. An Exploding Loss : the loss decreases up to a certain number of training steps and then suddenly increases with further training steps
    • Data Debugging for the raw data
      • Check for anomalous values in the input data (NaNs, exploding gradients due to anomalous data, division by zero, logarithm of zero or negative numbers). See the first code sketch below.
    • Data Debugging for the engineered data
      • Check for anomalous data in the batches and in the engineered data.
      • If the cause is outlying data, shuffle the data to ensure that outliers are evenly distributed between batches.
  3. Contradictory Metrics : an ideal loss curve, but recall is stuck at 0
    • Model Debugging
      • The examples' classification probability is never higher than the threshold$^✔$ for positive classification (this often occurs with a large class imbalance) ⇒ lower your classification threshold. See the second code sketch below.
      • Check threshold-invariant metrics such as AUC.
  4. Overfitting : the testing loss is too high
    • Data Debugging
      • Check that the training and test splits are statistically equivalent.
    • Model Debugging
      • Reduce model capacity.
      • Add regularization.
  5. Model Gets Stuck : repetitive, step-like behavior of the loss
    • Data Debugging
      • Check whether the input data itself exhibits repetitive behavior ⇒ shuffle the data to remove it.

$^✔$ The default tf.keras threshold for positive classification is 0.5.
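
For the exploding-loss case, a batch-level anomaly check might look like the sketch below. The check_batch helper is hypothetical; it only flags value patterns (NaN, inf, non-positive values later fed to a log) that commonly blow up the loss.

```python
import numpy as np

def check_batch(batch: np.ndarray, name: str = "batch") -> list[str]:
    """Flag value patterns in a batch that commonly cause an exploding loss."""
    problems = []
    if np.isnan(batch).any():
        problems.append(f"{name}: contains NaN values")
    if np.isinf(batch).any():
        problems.append(f"{name}: contains +/-inf (possible division by zero upstream)")
    if (batch <= 0).any():
        # Only relevant if this feature is later passed through log().
        problems.append(f"{name}: contains zero/negative values (unsafe for log)")
    return problems

# Toy usage: a batch with one injected anomaly.
rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, size=(32, 8))
batch[3, 2] = np.nan
print(check_batch(batch))
```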
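
For the contradictory-metrics case, the sketch below uses synthetic prediction scores to show recall stuck at 0 under the default 0.5 threshold and recovering once the threshold is lowered, while a threshold-invariant metric (AUC) shows the model still ranks positives well. scikit-learn and the specific numbers are assumptions.

```python
import numpy as np
from sklearn.metrics import recall_score, roc_auc_score

# Hypothetical scores on an imbalanced validation set (~5% positives).
rng = np.random.default_rng(0)
y_true = (rng.random(1000) < 0.05).astype(int)
# Positives score higher than negatives, but never above 0.5.
y_prob = 0.2 * y_true + rng.random(1000) * 0.25

# With the default 0.5 threshold, recall is stuck at 0.
print("recall @ 0.5:", recall_score(y_true, (y_prob >= 0.5).astype(int)))

# Lowering the threshold recovers recall without retraining the model.
print("recall @ 0.3:", recall_score(y_true, (y_prob >= 0.3).astype(int)))

# A threshold-invariant metric shows the model still ranks positives well.
print("AUC:", roc_auc_score(y_true, y_prob))
```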


Source & Reference : Interpreting Loss Curves | Testing and Debugging in Machine Learning