상세 컨텐츠

본문 제목

[Multimodal Fusion 전략 비교 실험 — MOSAIC-ST에서 어떤 Fusion이 효과적인가]

졸프

by jii 2026. 5. 8. 20:57

본문

이 글은 종양 예측 프레임워크 MOSAIC-ST 구현 과정에서 수행한 fusion strategy 비교 실험 과정을 정리하였다. 단순 concat부터 attention, similarity, gated fusion까지 네 가지 전략을 직접 비교하고, 인코더 구조와의 상호작용까지 함께 분석하였다. 

 

1. 실험 배경

멀티모달 학습에서 "어떻게 두 modality를 합치느냐"는 성능에 생각보다 큰 영향을 준다. 단순히 feature를 이어붙이는 것(concat)과, 두 modality 간 관계를 명시적으로 모델링하는 것(attention, gated) 사이에는 분명한 차이가 존재한다. MOSAIC-ST는 WSI 이미지와 Spatial Transcriptomics(ST) 데이터를 통합하는 프레임워크인데, spot 단위에서 이미지 특징과 ST 특징을 추출한 뒤 이를 하나의 표현으로 합치는 fusion 단계가 핵심이다. 이 fusion 전략을 어떻게 설계하느냐에 따라, 같은 인코더 구조에서도 성능이 크게 달라졌다. 추가로, fusion 전략의 효과는 인코더 구조에 따라 다르게 나타난다는 점도 이번 실험에서 확인하였다.

 

2. 실험 설정

본격적인 실험에 앞서 전체 모델 구조를 살펴보면, 다음과 같다. 

 

실험에서는 Fusion Module 부분만 바꾸면서, 그리고 동시에 ST Encoder 구조도 두 가지 버전으로 바꾸면서 비교하였다. 

 

-인코더 구조 비교

 

Fusion 전략 실험을 인코더 구조와 함께 비교한 이유는, 두 요소가 독립적이지 않기 때문이다. ST 정보를 어떻게 인코딩하느냐에 따라 fusion 단계에서 요구되는 것이 달라진다.

 

Ver1 — Triple-Encoder

 

유전자 발현 정보와 공간 좌표 정보를 완전히 분리된 두 개의 인코더로 처리한다. 결과적으로 이미지, 유전자, 좌표 세 가지 표현이 fusion 단계에 들어오게 된다. 

class TripleEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.img_encoder      = ImageEncoder()       # ResNet-18
        self.gene_encoder     = GeneEncoder()        # scBERT-inspired Performer
        self.spatial_encoder  = SpatialEncoder()     # Transformer (2D coord → 256-dim)

    def forward(self, image, gene_expr, spatial_coord):
        v  = self.img_encoder(image)                 # (N, 256)
        h_gene = self.gene_encoder(gene_expr)        # (N, 256)
        h_sp   = self.spatial_encoder(spatial_coord) # (N, 256)
        return v, h_gene, h_sp

 

유전자 인코더는 scBERT 방식에 착안해서 유전자를 토큰으로 다루는 Performer 구조를 사용한다. 발현값을 7개 구간으로 이산화해서 lookup table로 임베딩한다.

class GeneEncoder(nn.Module):
    def __init__(self, n_genes=2000, top_k=512, d_model=256):
        super().__init__()
        self.top_k = top_k
        self.bin_emb = nn.Embedding(7, d_model)   # 발현값 이산화: 0~6 구간
        self.id_emb  = nn.Embedding(n_genes, d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.performer = PerformerEncoder(depth=6, heads=8, dim=d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, gene_expr):
        topk_vals, topk_idx = gene_expr.topk(self.top_k, dim=-1)

        # 발현값 이산화
        bins = torch.bucketize(
            torch.log1p(topk_vals),
            boundaries=torch.linspace(0, 5, 6)
        ).clamp(0, 6)                              # (N, K)

        tokens = self.id_emb(topk_idx) + self.bin_emb(bins)   # (N, K, D)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)   # (N, 1, D)
        seq = torch.cat([cls, tokens], dim=1)                  # (N, K+1, D)

        out = self.performer(seq)
        return self.proj(out[:, 0])                # CLS pooling → (N, D)

 

Ver2 — Dual-Stream (Spatial-Aware ST Encoder)

 

공간 좌표를 독립된 인코더로 분리하지 않고, 유전자 인코딩 과정 내부에 공간 정보를 통합한다. spatial token과 gene token이 함께 Transformer를 통과하면서 self-attention을 공유한다.

 

class SpatialAwareSTEncoder(nn.Module):
    def __init__(self, n_genes=2000, top_k=512, d_model=256):
        super().__init__()
        self.top_k = top_k
        self.gene_id_emb  = nn.Embedding(n_genes, d_model)
        self.gene_pos_emb = nn.Embedding(top_k, d_model)
        self.expr_proj    = nn.Sequential(
            nn.Linear(1, d_model), nn.LayerNorm(d_model), nn.GELU()
        )
        self.spatial_proj = nn.Linear(2, d_model)   # (x, y) → 256-dim

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=4)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, gene_expr, spatial_coord):
        topk_vals, topk_idx = gene_expr.topk(self.top_k, dim=-1)

        # gene token = identity + position + expression value
        gene_tokens = (
            self.gene_id_emb(topk_idx)                            # (N, K, D)
            + self.gene_pos_emb(torch.arange(self.top_k).to(topk_idx.device))
            + self.expr_proj(topk_vals.unsqueeze(-1))
        )

        # spatial token
        s_token = self.spatial_proj(spatial_coord).unsqueeze(1)   # (N, 1, D)

        # [spatial | gene] 통합 후 Transformer
        seq = torch.cat([s_token, gene_tokens], dim=1)            # (N, K+1, D)
        out = self.transformer(seq)                                # (N, K+1, D)

        # spatial token을 query로, gene token들과 cross-attention
        q = out[:, 0:1, :]    # (N, 1, D)
        k = v = out[:, 1:, :] # (N, K, D)
        d = q.size(-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) / d**0.5, dim=-1)
        h_st = (attn @ v).squeeze(1)                              # (N, D)

        return self.out_proj(h_st)

 

결국 두 구조의 핵심적인 차이는 공간 정보를 독립 신호로 다루느냐, 유전자 표현을 해석하는 조건 정보로 다루느냐이다.

 

-Fusion 전략 비교

 

이는 각 spot에서 이미지 특징 v와 ST 특징 s를 받아 통합된 표현 z를 만드는 방식에 관한 것이다. 

 

1. Concatenation-based MLP

class ConcatFusion(nn.Module):
    def __init__(self, d=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d * 2, d),
            nn.LayerNorm(d),
            nn.ReLU(),
            nn.Linear(d, d)
        )

    def forward(self, v, s):
        return self.mlp(torch.cat([v, s], dim=-1))

 

가장 단순한 방법이다. 두 표현을 이어붙인 뒤 MLP로 압축한다. 단순하지만 두 modality 간 명시적 관계 모델링은 없다.

 

2. Attention-based Fusion

class AttentionFusion(nn.Module):
    def __init__(self, d=256, n_heads=4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d)
        self.proj = nn.Linear(d, d)

    def forward(self, v, s):
        # v를 query, s를 key/value로 — 이미지가 ST에서 무엇을 참조할지 학습
        v_ = v.unsqueeze(1)   # (N, 1, D)
        s_ = s.unsqueeze(1)   # (N, 1, D)
        out, _ = self.cross_attn(v_, s_, s_)
        out = self.norm(out.squeeze(1) + v)   # residual
        return self.proj(out)

 

이미지 표현이 ST 표현을 참조하는 방향의 cross-attention이다. 두 modality 간 명시적 상호작용을 모델링한다.

 

3. Similarity-based Fusion

class SimilarityFusion(nn.Module):
    def __init__(self, d=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(d * 3, d),
            nn.LayerNorm(d),
            nn.ReLU(),
            nn.Linear(d, d)
        )

    def forward(self, v, s):
        elem_prod = v * s                                           # element-wise product
        abs_diff  = (v - s).abs()                                  # absolute difference
        cos_sim   = F.cosine_similarity(v, s, dim=-1, eps=1e-8)   # (N,)
        cos_sim   = cos_sim.unsqueeze(-1).expand_as(v)            # (N, D)

        combined = torch.cat([elem_prod, abs_diff, cos_sim], dim=-1)  # (N, 3D)
        return self.proj(combined)

 

두 표현 사이의 유사도를 다양한 방식으로 계산해서 조합한다. 두 modality가 얼마나 일치하는지를 명시적으로 포착하려는 방식이다.

 

4. Gated Fusion

class GatedFusion(nn.Module):
    def __init__(self, d=256):
        super().__init__()
        self.gate_v = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())
        self.gate_s = nn.Sequential(nn.Linear(d * 2, d), nn.Sigmoid())
        self.proj   = nn.Linear(d, d)

    def forward(self, v, s):
        combined = torch.cat([v, s], dim=-1)
        g_v = self.gate_v(combined)   # 이미지 특징의 기여도
        g_s = self.gate_s(combined)   # ST 특징의 기여도
        z   = g_v * v + g_s * s
        return self.proj(z)

 

각 modality의 기여도를 입력에 따라 동적으로 조절하는 방식이다. 샘플마다 이미지와 ST 중 어느 쪽이 더 신뢰할 만한지를 학습한다.

 

3. 실험 결과

 

-분석

 

1. 인코더 구조가 fusion보다 더 중요하다.

 

가장 먼저 눈에 띄는 것은 Ver1 (Triple-Encoder)의 전반적인 성능 저하이다. 어떤 fusion 전략을 사용해도 Ver0, Ver2 대비 낮은 성능을 보였다. Triple-Encoder는 유전자 발현과 공간 좌표를 독립된 인코더로 분리해서 처리하는 구조이다. 직관적으로는 명확한 분리가 효과적일 것 같지만, 결과는 반대였다. 제한된 데이터 환경에서 세 modality 간 상호작용이 오직 fusion 단계에서만 이루어지다 보니, 표현 학습 자체가 충분히 이루어지기 어려웠던 것으로 보인다.

 

Ver2는 공간 정보를 유전자 인코딩 과정에 일찍 통합하는 early integration 구조이다. Transformer 내에서 gene token과 spatial token이 self-attention을 공유하기 때문에, 인코딩 단계에서 이미 공간 맥락이 반영된 ST 표현이 만들어진다. 이는 fusion 단계의 부담이 줄어드는 효과가 있다.

 

2. Ver2에서는 Concat이 가장 효과적이다

Ver2 + Concat이 84.62%로 전체 최고 성능이다. Attention fusion은 Ver0에서는 82.05%로 두 번째로 좋은 성능을 보였지만, Ver2에서는 76.92%로 오히려 성능이 떨어졌다. Ver2의 ST 인코더가 이미 spatial token을 query로 gene attention을 수행하고 있기 때문에, fusion 단계에서 다시 cross-attention을 적용하면 중복이 발생하거나 학습이 불안정해지는 것으로 해석할 수 있다. Gated fusion은 두 인코더 구조 모두에서 안정적인 편이다(74.36% ~ 80.77%). 특정 전략에 크게 의존하지 않는 대신, 최고 성능도 내지 못하는 경향이 있다.

 

3. Fusion 전략의 효과는 인코더 구조에 따라 달라진다

같은 fusion 전략이라도 인코더에 따라 결과가 역전된다. Attention은 Ver0에서 82.05%, Ver2에서 76.92%이고, Concat은 Ver0에서 80.77%, Ver2에서 84.62%이다. 이는 fusion 전략을 선택할 때 인코더 구조와의 상호작용을 반드시 함께 고려해야 한다는 점을 시사한다. “어떤 fusion이 일반적으로 좋다”는 식의 결론을 내리기 어렵고, 인코더가 어떤 표현을 만들어내느냐에 따라 최적의 fusion이 달라진다.

 

4. 결론

단순히 "어떤 fusion이 좋다"가 아니라, 인코더 설계와 fusion 전략이 함께 맞물려야 한다는 점이 이번 실험의 가장 중요한 takeaway이다. 멀티모달 모델을 설계할 때 fusion 전략을 독립적으로 ablation하기보다, 인코더 구조와 세트로 비교하는 것을 권장한다.

'졸프' 카테고리의 다른 글

주제 관련 논문 리딩 및 데이터셋 조사  (0) 2025.11.27
방향성 및 데이터셋 고민  (0) 2025.09.27

관련글 더보기

댓글 영역