논문리뷰

[논문리뷰] PromptStyler: Prompt-driven Style Generationfor Source-free Domain Generalization

minyoung lee 2024. 2. 25. 22:15

이 논문은 ICCV 2023에 accept된 논문으로 저자는 Junhyeong Cho1 Gilhyun Nam1 Sungyeon Kim2 Hunmin Yang1,3 Suha Kwak2 (1ADD 2 POSTECH 3KAIST)이다.

https://promptstyler.github.io/

 

PromptStyler

Figure 3. PromptStyler learns diverse style word vectors which do not distort content information of style-content prompts. After learning style word vectors, we synthesize style-content features (e.g., from "a S1 style of a dog") via a pre-trained text en

promptstyler.github.io

https://arxiv.org/abs/2307.15199

 

PromptStyler: Prompt-driven Style Generation for Source-free Domain Generalization

In a joint vision-language space, a text feature (e.g., from "a photo of a dog") could effectively represent its relevant image features (e.g., from dog photos). Also, a recent study has demonstrated the cross-modal transferability phenomenon of this joint

arxiv.org

 

1. Motivation

이 논문은 domain generalization에 관한 논문이다.

다양한 distribution shifts를 joint space에서 simulate하기 위해서 다양한 style을 생성해내었다. 이 때 어떠한 scoure domain image없이 오직 prompt를 통해서 domain generalization을 다루었다.

 

저자들은 질문을 하나 던졌다.

"source domain을 사용하지 않고도 large-scale 모델의 latent space에서 다양한 domain shfit를 simulate하여 domain generalization 성능을 향상시킬 수 있는가?"

 

저자들은 이미지의 shared style이 domain을 특정지을 수 있고, 이러한 share style은 learnable word vector를 통해서 capture될 수 있다고 하였다. 이 때 learnable word vector는 CLIP을 이용한 pseudo word의 learnable vector를 의미한다.

 

 

그림은 이 논문의 motivation을 보여준다. 어떠한 이미지도 없이 다양한  style을 synthesiz하는 것을 기대한다.

 

 

다양한 distribution shfits를 효과적으로 simulate하기 위해서 style diversity를 maximize하였다. 이는 Figure 2의 왼쪽 Style Diversity 파트이다. 밑에 쪽에 모여있었던 style이 Style Diversity Loss가 적용된 이후에는 다양하게 퍼져있는 것을 확인할 수 있다.

 

또한 style뿐만 아니라 content consistency도 고려해야한다. content information을 왜곡시키지 말아야하기 때문이다.

따라서 style-content prompt로부터 style-content feature를 얻는다. 이 때 style-content prompt ("a S* style of a [class]")는 각각 content prompt ("class")와 가까이 있도록 한다. 즉, "cat"과 "a S1 style of a cat"이 가까워지도록 한다.

 

이렇게 얻어진 learned style word vectors는 classifier를 학습시키기 위해서 style-content features로 이용된다.

이 synthesized features는 다양한 unknown styles로 known contents를 simulate할 수 있다.

 

Text encoder는 training시에만 사용하고, Image encoder는 inference때만 사용한다.

 

Contributions

1. vision-language space에서 prompt를 통해서 다양한 style을 synthesize하고 이를 통해서 sorce-free domain generalization을 효과적으로 이루었다.

2. known content를 다양한 unknown styles로 image를 simulate하는 효과적인 새로운 method를 제시하였다.

3. 어떠한 이미지도 사용하지 않고 Domain Generalization (DG) 밴치마크에서 SOTA를 달성하였다.

 

2. Method

다양한 style을 나타나기 위해서  style word vector를 학습하였다. 이후에 이 synthesized style-content feature를 이용해서 linear classifier를 학습하였다. 이 때 이 feature는 pre-treinad text encoder인 T를 이용하였다.

 

inference 시간에는 pre-trained image encoder I를 사용하여 input 이미지에 대한 image feature를 추출하였고, 이후에 학습된 linear classifier를 이용하였다.

 

2-1. Prompt-driven style generation

3가지 종류의 prompt를 사용하였다.

 

이 때 $s_i$는 처음에는 random하게 initialized되었다.

 

word vector를 학습하기 위해서 두가지의 design choice가 있었다.

1. sequential하게 각 style word vector를 학습하는 방법

2. parallel하게 모든 style word vector를 학습하는 방법

저자들은 1번째 방법이 더 적은 메모리를 사용하기 때문에 첫번째 방법을 사용하였다고 했다.

 

Style divesity loss

 

K styles의 다양성을 최대화하기 위해서 Style diversity loss를 적용하였다. 

 

 

이는 각 i번째 style feature에 대해 다른 존재하는 style feature들에 대한 consine similarity의 absolute value를 최소화시킨다. 결국 최대한 비슷함을 줄여서 style을 다양하게 해주었다.

 

Content consistency loss

 

style diversity loss만 이용해서 style word vector를 학습할 경우 undesirable outcome이 나올 수 있다.

style-content feature를 생성하는 과정에서 learned style은 content information을 왜곡시킬 수 있기 때문이다. 따라서 content consistency loss를 이용하였다.

 

각각의 style-content feature는 i-th style word vector가 content feature와 가장 큰 consine similarity score를 가지도록 synthesized된다.

 

i번째 style word vector에 대해 m번째 class에 대한 style-content feature와 n번째 class에 대한 content feature의 cosine similarity를 계산한다. 

이후, 각각의 style-content feature가 각 content feature에 맞추어서 가까이 있도록 하는 역할을 한다. 

결과적으로 이는 content information을 유지하기 위함이고, content가 cat이면 cat끼리 가까워지게 된다.

 

최종 loss는 style loss와 content loss가 더해져서 이루어진다.

 

 

2-2. Training a linear classifier using diverse styles

K개의 style word vector를 학습하고나서 linear classifier를 학습시키기 위해서 KN style content feature를 생성한다. 이는 K개의 style과 pre-defined된 N개의 class를 text encoder T를 통해서 synthesize된 것이다.

 

이 때 classification loss로 ArcFace loss를 사용하였다. 이 loss는 다른 class 에서 온 feature들을 더 멀리 있도록 하여 discriminative하게 한다.

 

2-3. Inference using the trained classifier

classifier는 pre-trained image encoder I를 infernece time 때 사용하였다.

 

input image x를 iamge encoder에 넣어서 image feature를 얻고, 학습된 classifier는 class score를 l2 noramzliaed image feature를 통해서 생성해내게 된다.

 

3. Experiments

3-1. Evaluation datasets

- PACS : 4 domains, 7 classes

- VLSC : 4 domains, 5 classes

- OfficeHome : 4 domains, 65 classes

- DomainNet : 6 domains, 345 classes

 

3-2. Implementatino details

- 대규모 pre-trained vision-language model로 CLIP을 사용하였다.

- text encoder T : transformer

- image encoder I : ResNet-50

 

- style word vector 학습은 기존의 prmopt learning method를 따랐다고 한다.

 

3-3. Evaluations

 

기존의 DG 밴치마크와 비교해보았을 때 PromptStyler는 Source Domain, Domain Descriptioin 모두 사용하지 않았는데도 높은 성능을 보였다.

 

 

t-SNE 결과이다. style loss를 적용했을 때는 다른 색(style)이 서로 퍼져있는 것을 확인할 수 있고, content loss를 적용했을 때는 같은 shape끼리 모여있는 것을 볼 수 있다.

style과 content loss를 같이 적용했을 때 shape끼리 모여있고 그 안에서의 style들은 다양하게 분포한 것을 알 수 있다.

 

정량적인 결과로도 style loss와 content loss를 모두 사용한 경우가 accuracy가 가장 높은 것을 알 수 있다.

 

style-content feature를 통한 결과를 살펴보면, cat이라는 content를 왜곡시키지 않으며 다양한 style이 적용된 것을 알 수 있다.

 

 

classifiaction loss로 Softmax를 사용했을 때보다 ArcFace 사용한 결과가 더 좋은 것을 알 수 있다.

 

4. Conclusion

CLIP을 이용하여 prompt를 통해 다양한 style을 생성해내고 source-free domain generalization에서 SOTA 성능을 달성하였다.

 

 

5. Discussion

CLIP에 대한 영향이 어느정도 미치는지에 대한 결과와 prompt를 다르게 했을 때의 성능이 궁금하다.