논문리뷰

[논문리뷰] Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning

minyoung lee 2024. 3. 30. 21:18

 

Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning

 

이 논문은 CVPR 2024에 accept된 논문으로, 저자는 Woo-Jin Ahn1 Geun-Yeong Yang1 Hyun-Duck Choi2* Myo-Taeg Lim1* 1Korea University 2Chonnam National University 이다.

 

논문 제목에서 알 수 있듯이 Domain Generalization Semantic Segmentation 논문이다.

 

1. Motivation

 

Synthetic datasets에서 학습된 모델은 real-world 시나리오인 상황에서 좋지 않은 성능이 보이는 경우가 많았다.

이러한 Domain Shift문제는 style factor 차이 때문이다.

 

따라서 DASS (Domain Aaptation Semantic Segmentaion) task 로 연구가 진행되었다. 

Domain Adaptation은 Source와 Target 도메인의 차이를 줄이는 것을 목표로 하지만,

Target domain이 train 시간에 접근 가능해아한다는 단점이 있다.

 

이 논문에서 task로 잡고 있는 DGSS (Domain Generalized Semantic Segmentation) task는

Source Domain으로만 train하며, domain-invariant feature를 extract하는 것을 목표로 한다.

이 DGSS의 main techniques로는 

1) Domain Randomization (DR)

2) Feature Normalization (FN)

이 있다.

 

1) Domain Randomization (DR)

DR는 training set을 Augmentation하는 방식으로, training data에 존재하는 특정 style에 overfitting되는 것을 막는 역할을 한다.

그러나 auxiliary domains에 의존한다는 단점이 있다.

 

2) Feature Normalization (FN)

FN은 마찬가지로 train data의 특정 style에 overfitting되지 않는 것을 목표로 하고 이를 위해서 feature를 regularize한다.

이를 통해서 Domain-specific style 정보를 없애며, 이런한 FN 예시로 Instance Normalization이 있다. 

그러나 FN은 semantic content와 style의 정보가 서로 entangled 되어있기 떄문에 Semantic Content 까지 없앨 수 있다는 단점이 있다. 

 

따라서 이 논문은 BlindNet을 제시하였다.

Encoder 부분에서는 Covariance aligment를 통해서 style을 Bline하였다. 

Decoder 부분에서는 Semantic consistency contrastive learning을 통해서 robustness를 향상시키도록 하였다.

 

2. Method

 

Notation은 아래와 같다.

 

 

 

 

전반적인 architecture는 아래 Figure2와 같다.

 

1) Encoder

우선 Encoder 부분을 살펴보면, 크게 Covariance Matching (CM) loss와 Cross-Covariance (CC) loss가 있다.

 

1-1) Covariance Matching (CM) 

이 Loss는 content information은 삭제하지 않고, 네트워크가 다양한 style을 uniformly하게 인식할 수 있도록 하는 것을 목표로 한다.

따라서다른 style에서의 covariance matrix 들의 차이를 줄이는 방향으로 loss를 제안한다.

 

original 이미지와 augmented 이미직 pair로 입력으로 들어가게 된다. 

이후, i번째 encoder block에서 feature를 뽑은 후, 이 feature들을 normalization한다.

normalization 된 feature들로 covariance matrices를 구한다.

 

 

두 covariance matrix의 L2 loss를 적용한다.

 

1-2) Cross-Covariance loss

 

이 CC loss에서는 cross-covariance matrix를 구한다.

이후, content information이 없어지는 것을 방지하기 위해, 위에서 구한 cross-covariance matrix의 diagonal 원소들이 1에 가까워지도록 loss를 제안하였다.

 

 

이렇게 Encoder에서는 두개의 CC, CM loss를 통해서 style-blined feature를 generate하는 것에 focus한다.

 

2) Decoder

 

Decoder파트에서는 robustness를 향상시키기 위해서 두가지 loss를 제안한다.

 

Decoder에서는 contrastive learning을 이용하고 이때 InfoNCE loss를 이용하였다.

 

 

이 논문에서 Anchor는 augmented image에서 가져오고, Positive sample은 anchor와 같은 위치의 original image에서 설정한다.

 

2-1) Class-wise Contrastive Learning (CWCL)

 

 

decoder에서 j번째 block에서 나온 feature를 이용한다.

anchor는 augmented image로 부터 나온 feature의 (m,n) 위치의 pixel이며, 

positive sample은 anchor의 위치에 대응되는 original image의 pixel이다.

negative sample은 anchor, positive sample와 다른 class를 가진 pixel들이 된다.

 

따라서 이런 anchor, positive, negative sample을 이용하여 InfoNCE loss를 구성하였다.

 

2-2) SDCL loss

추가로, Domain Shift로 인해서 비슷한 class끼리 entanglement되어있다는 문제점이 있다.

위의 그림에서도 road, sidewalk, building이 모두 엉켜져 있는 것을 알 수 있다.

 

따라서 이를 해결하기 위해서 Semantic Disentanglement Contrastive Learning (SDCL) loss를 제안하였다.

CWCL loss와 projection head를 공유한다.

 

이때 anchor는 missclasssify된 pixel이다.

Negative Sample은 augmented image feature에서 뽑아내고 anchor가 잘못 판단한 class에 대응되는 pixel들이다.

Positive Sample은 잘못 분류된 Anchor의 올바른 class이다.

 

즉, 파란색 부분처럼 잘못 분류된 pixel이 있다고 가정해볼 수 있다. 이 때 해당 pixel은 실제로는 road(보라색)으로 분류되어야하지만 Sidewalk(핑크색)으로 잘못 분류되었다.

 

이 때 negative sample들은 sidewalk에 해당하는 augmented feature의 pixel들이며,

positive sample은 옳게 분류된 road의 pixel이다. 

 

 

따라서 total loss는 제시한 loss들과 CE loss를 합쳐서 이루어진다.

 

3. Experiment

 

Architecture는 DeepLabV3+에서 ResNet-50, ShuffleNetV2, MobileNetV2 기반으로 실험하였다.

 

Dataset은 Synthetic Dataset, Real-World Dataset으로 이루어져 있다.

Synthetic Datasets은 GTAV (G), SYNTHIA (S) 이고, 

Real-World Datasets은 Cityscapes (C) BDD-100K (B), Mapillary (M)이다.

 

 

기존의 DGSS 방법과의 비교

두 가지 시나리오에 대해서 실험을 하였다.

1) trained on GTAV --> test on Cityscaeps, BDD-100K, Mapillary

2) trained on Cityscapes --> test on BDD-100K, Mapillary, SYNTHIA 

 

 

시나리오 1에 대해서는 ResNet50, ShuffleNetV2 기반인 경우 SoTA를 달성하였고,

시나리오 2에 대해서는가장 좋은 성능 또는 두번째로 좋은 성능에 달성하였다.

 

MobileNetV2에서도 SoTA를 달성하였다.

추가적으로 DIRL, DPCL에서는 External Module을 사용한 반면, 해당 논문은 External Module을 사용하지 않고도 SoTA를 달성하였다.

 

 

Qualitative 실험으로 봤을 때도 이전 연구들과 비교했을 때 더 정교하게 분류한 것을 알 수 있다.

 

Computational Cost를 비교했을 때, External Module을 사용하지 않았기 때문에 Parameter수와 GLOPS, Time에서 모두 효과적인 것을 알 수 있다.

 

 

모든 Loss를 사용했을 때 성능이 가장 좋게 나왔고, 특히 CWCL과 SDCL은 같이 사용해야 효과가 좋았다.

 

CWCL에서 sampling number를 비교해봤을 때,

diversity of classes가 증가되었을 때 성능이 향상된 것을 알 수 있었고,

negative samples의 경우에는 balanced number가 필요하다는 것을 알 수 있다.

 

SDCL loss에서 사용된 Projection Head는 CWCL loss와 각각의 projection head를 사용하는 경우, 공유하지만 Stop Gradient를 하는 경우, Shared를 한 경우를 비교해봤을 때

 

공유한 경우가 가장 좋은 성능을 보였다.

 

SCCL loss의 효과를 살펴보면, loss를 적용했을 때 class별로 잘 분류된 것을 알 수 있다.

 

 

4. Conclusion

 

해당 논문은 covariange alignment와 semantic consistency contrastive learning을 통해서 Domain Generalization task를 해결하는 새로운 방법인 BlindNet을 제시하였다.