Study/NLP

[딥러닝] LLaVA 모델 정리(멀티 모달 LLM, LLaMA + CLIP(Contrastive Language–Image Pretraining)-ViT)

Railly Linker 2025. 6. 5. 18:42

- 이번 포스팅에서는 멀티 모달 LLaVA (Large Language and Vision Assistant) 모델을 정리하겠습니다.

이전 게시글에서 Langchain 으로 사용해 보았던 바로 그 모델인데, 이번에는 단순한 모델의 사용 방법이 아닌 멀티 모달 LLM 모델의 내부 구조에서 원리까지를 자세히 다루어볼 것입니다.

 

- 현 시점 오픈소스 최고 성능 모델은 2024 년 12월에 출시된 DeepSeek-VL2 모델입니다.

제가 굳이 2023 년 4 월에 출시된 LLaVA 모델을 리뷰하는 이유는, LLaVA 모델이 기존에 제가 정리한 LLaMA 모델을 기반으로 나온 모델이며, GPT-4o 에 필적하는 DeepSeek-VL2 모델이 기존 멀티 모달 모델과 어떠한 차이점으로 성능을 높인 것인지에 대해 비교하며 알 수 있는 가장 좋은 예시라고 생각했기 때문입니다.

 

- 멀티 모달 모델은 위 게시글 링크에서 설명했듯, 기본적으로는 이미지 인코더 + LLM 모델 의 구조입니다.

이미지에서 의미를 추출하는 이미지 인코딩 기술인 ViTCNN 과 같은 기술을 이해하셔야 하며,

이렇게 추출된 이미지의 의미 벡터를, 토큰 임베딩으로 추출한 텍스트의 의미 벡터와 결합하여 LLM 에 입력해주는 것으로 간단히 설명이 가능합니다.

 

LLaVA 의 멀티모달은 이미지 인코더 ViT 에 더하여 LLM 모델은 LLaMA 를 사용하므로, LLaMA 정리글을 참고하세요.

 

(CLIP (Contrastive Language–Image Pretraining) )

- LLaVA 모델은 이미지 인코딩 부분에 CLIP-ViT 를 사용하며, LLM 모델에는 LLaMA 를 사용합니다.

LLaMA 모델은 이미 이전에 정리했으므로 생략하고, 먼저 CLIP-ViT 에 대해 설명하겠습니다.

 

- CLIP-ViT는 OpenAI가 발표한 CLIP (Contrastive Language–Image Pretraining) 모델에서 사용된 Vision Transformer (ViT) 아키텍처입니다.

쉽게 말하면, CLIP 은 텍스트와 이미지를 동시에 학습하는 멀티모달 모델이고,

이 모델의 백본으로 사용된 것이 ViT 로, 즉 LLaVA 에서는 CLIP 에서 사전 학습된 ViT 를 사용한다는 것입니다.(백본은 중요치 않습니다. CLIP 의 핵심은 이미지와 텍스트간 유사도를 학습하는 구조를 의미합니다.)

 

아래 CLIP 모델 코드를 보고 설명하겠습니다.

 

- CLIP 코드

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels=3, embed_dim=768):
        super().__init__()
        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        nn.init.normal_(self.proj.weight, std=0.02)
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H/patch, W/patch]
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim

        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=False)

        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-5)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.normal_(self.mlp[3].weight, std=0.02)
        nn.init.zeros_(self.mlp[3].bias)

        nn.init.normal_(self.attn.in_proj_weight, std=0.02)
        if self.attn.in_proj_bias is not None:
            nn.init.zeros_(self.attn.in_proj_bias)
        nn.init.normal_(self.attn.out_proj.weight, std=0.02)
        if self.attn.out_proj.bias is not None:
            nn.init.zeros_(self.attn.out_proj.bias)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        residual = x
        x_norm = self.norm1(x)
        x_attn, _ = self.attn(
            x_norm, x_norm, x_norm,
            attn_mask=attn_mask,  # causal mask 등에 사용
            key_padding_mask=key_padding_mask  # padding mask 사용
        )
        x = x_attn + residual

        residual2 = x
        x_norm2 = self.norm2(x)
        x_mlp = self.mlp(x_norm2)
        x = x_mlp + residual2

        return x


class VisionTransformer(nn.Module):
    def __init__(
            self,
            image_size=224,
            patch_size=16,
            in_channels=3,
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.0,
            dropout=0.0
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

        self.pos_dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]

        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, embed_dim]
        x = torch.cat((cls_tokens, x), dim=1)  # [B, num_patches+1, embed_dim]
        x = x + self.pos_embed  # [B, seq_len, embed_dim]
        x = self.pos_dropout(x)

        x = x.transpose(0, 1)  # [seq_len, B, embed_dim]

        for layer in self.layers:
            x = layer(x, attn_mask=None, key_padding_mask=None)

        x = x.transpose(0, 1)  # [B, seq_len, embed_dim]
        x = self.norm(x)  # [B, seq_len, embed_dim]

        return x[:, 0, :]  # [B, embed_dim]


class TextTransformer(nn.Module):
    def __init__(
            self,
            vocab_size,
            max_length=77,
            embed_dim=512,
            depth=12,
            num_heads=8,
            mlp_ratio=4.0,
            dropout=0.0,
            pad_token_id=0,
            eos_token_id=None
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id if eos_token_id is not None else vocab_size - 1

        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        nn.init.normal_(self.token_embed.weight, std=0.02)

        self.pos_embed = nn.Parameter(torch.zeros(1, max_length, embed_dim))
        nn.init.normal_(self.pos_embed, std=0.02)

        self.pos_dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim, eps=1e-5)

        self.register_buffer('_default_eos_tensor', torch.tensor(max_length - 1, dtype=torch.long))

    def forward(self, tokens):
        B, L = tokens.size()
        assert L <= self.max_length, "입력 시퀀스 길이가 max_length를 초과했습니다."

        x = self.token_embed(tokens)  # [B, L, embed_dim]
        x = x + self.pos_embed[:, :L, :]  # [B, L, embed_dim]
        x = self.pos_dropout(x)

        x = x.transpose(0, 1)  # [L, B, embed_dim]

        causal_mask = torch.triu(torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1)

        pad_mask = tokens.eq(self.pad_token_id)  # [B, L]

        for layer in self.layers:
            x = layer(
                x,
                attn_mask=causal_mask,  # 미래 마스킹
                key_padding_mask=pad_mask  # 패딩 토큰 마스킹
            )

        x = x.transpose(0, 1)  # [B, L, embed_dim]
        x = self.norm(x)  # [B, L, embed_dim]

        eos_mask = tokens.eq(self.eos_token_id)  # [B, L] boolean
        eos_positions = torch.where(
            eos_mask.any(dim=1),
            eos_mask.float().argmax(dim=1),
            self._default_eos_tensor.expand(B).to(x.device)
        )  # shape: [B]

        output = x[torch.arange(B, device=x.device), eos_positions, :]  # [B, embed_dim]
        return output


class CLIP(nn.Module):
    def __init__(
            self,
            image_size=224,
            patch_size=16,
            in_channels=3,
            vision_embed_dim=768,
            vision_depth=12,
            vision_heads=12,
            vision_mlp_ratio=4.0,
            text_vocab_size=49408,
            text_max_length=77,
            text_embed_dim=512,
            text_depth=12,
            text_heads=8,
            text_mlp_ratio=4.0,
            embed_dim=512,
            dropout=0.0,
            pad_token_id=0,
            eos_token_id=None
    ):
        super().__init__()
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=vision_embed_dim,
            depth=vision_depth,
            num_heads=vision_heads,
            mlp_ratio=vision_mlp_ratio,
            dropout=dropout
        )
        self.visual_proj = nn.Linear(vision_embed_dim, embed_dim, bias=False)
        nn.init.normal_(self.visual_proj.weight, std=0.02)

        self.textual = TextTransformer(
            vocab_size=text_vocab_size,
            max_length=text_max_length,
            embed_dim=text_embed_dim,
            depth=text_depth,
            num_heads=text_heads,
            mlp_ratio=text_mlp_ratio,
            dropout=dropout,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim, bias=False)
        nn.init.normal_(self.text_proj.weight, std=0.02)

        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))

    def encode_image(self, image):
        image_features = self.visual(image)  # [B, vision_embed_dim]
        image_embeds = self.visual_proj(image_features)  # [B, embed_dim]
        return image_embeds

    def encode_text(self, text):
        text_features = self.textual(text)  # [B, text_embed_dim]
        text_embeds = self.text_proj(text_features)  # [B, embed_dim]
        return text_embeds

    def forward(self, image, text):
        image_embeds = self.encode_image(image)  # [B, embed_dim]
        text_embeds = self.encode_text(text)  # [B, embed_dim]

        image_embeds = F.normalize(image_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)

        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_embeds @ text_embeds.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

 

위에서 보이듯 CLIP 은 단순한 구조를 가집니다.

 

이전에 정리했던 ViT 를 사용하여 이미지를 인코딩 하고,

동시에 텍스트를 BERT 와 같은 Transformer 기반 인코더로 인코딩 합니다.

 

인코딩된 결과물은 각각 이미지와 자연어의 의미가 녹아있는 Latent Vector 이므로, 비슷한 의미를 가진 정보끼리는 서로 비슷한 벡터가 나와야 합니다.

예를들어 강아지가 있는 사진은 강아지를 설명한 텍스트와 유사한 벡터가 나와야 하지요.

 

그렇기에 두 벡터간 Cosine Similarity 를 계산하여 같은 쌍으로 묶인 자연어와 이미지는 유사도가 높게, 다른 쌍은 유사도가 낮게 인코딩이 되도록 학습시키는 것이 CLIP 입니다.

 

CLIP 의 학습 방식의 핵심은 Contrastive Learning 으로, 이미지와 맞는 텍스트 쌍은 가까운 임베딩 벡터를 갖게 하고, 맞지 않는 쌍은 임베딩 공간에서도 멀어지도록 학습하는 것입니다.

 

학습 데이터는 Positive pair, 즉 서로 유사한 의미를 가지는 데이터 쌍과, Negative pair, 의미적으로 다른 데이터 쌍을 준비하고,

InfoNCE Loss 수식

 

위와 같은 InfoNCE 와 같은 CrossEntropy 계열의 손실 함수로, 이것을 사용한다는 의미는 의미가 같은지 아닌지에 대하여 분류하는 방식으로 학습을 진행한다는 것입니다.

 

예를들어,

Anchor Img Text1 Text2 Text3 Text4
이미지 A 텍스트 A 텍스트 B 텍스트 C 텍스트 D

 

위와 같이 한 이미지가 있고, 그에 대한 '하나'의 정답 텍스트가 있으며, 나머지는 전부 오답으로 하여 이미지 A 에 대해 올바른 텍스트를 찾는 문제라고 할 수 있습니다.

 

(LLaVA 설명)

[Input Image] ─> [Visual Encoder] ─┐
                                   │
                                                 
[Input Text] ─> [Text Tokenizer]  ─> [Projection] ──> [LLaMA] ──> [Text Output]

 

LLaVA 의 전체 구조는 위와 같습니다.

Image 와 Text 를 같이 입력받아서, Image 는 ViT에, Text 는 Text Encoder 에 넣고, 이를 합쳐서 LLaMA 모델에 넣어주는 것 뿐입니다.

 

- 전체 코드를 보며 설명하겠습니다.

import math
import torch
from torch import nn
import torch.nn.functional as F
from typing import Tuple, Optional


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.scale = math.sqrt(emb_size)

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        # tokens: (batch, seq_len)
        # output: (batch, seq_len, emb_size) scaled by sqrt(emb_size)
        return self.embedding(tokens) * self.scale


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2*k) → return (..., 2*k)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def build_rotary_pos_emb(dim: int, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # dim: 헤드당 차원(d_k), seq_len: 최대 시퀀스 길이
    inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    )  # (dim/2,)
    t = torch.arange(seq_len, dtype=torch.float32)  # (seq_len,)
    freqs = torch.einsum("i,j->ij", t, inv_freq)  # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
    cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, dim)
    sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, dim)
    return cos, sin


def apply_rotary_pos_emb(
        q: torch.Tensor,
        k: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        seq_len: int,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    cos = cos[:, :, :seq_len, :].to(dtype=dtype)
    sin = sin[:, :, :seq_len, :].to(dtype=dtype)
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot


class Attention(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

    def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None:
            scores = scores.float()
            scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        context = torch.matmul(p_attn, value)
        return context, p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(
            self,
            n_heads: int,
            d_model: int,
            dropout: float = 0.1,
            max_seq_len: int = 2048,
            use_rope: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.use_rope = use_rope

        self.lin_q = nn.Linear(d_model, d_model, bias=False)
        self.lin_k = nn.Linear(d_model, d_model, bias=False)
        self.lin_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        self.attn = Attention(dropout)

        if use_rope:
            cos, sin = build_rotary_pos_emb(self.d_k, max_seq_len)
            self.register_buffer("cos", cos)  # float32
            self.register_buffer("sin", sin)  # float32

    def forward(
            self,
            x_q: torch.Tensor,
            x_k: torch.Tensor,
            x_v: torch.Tensor,
            mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x_q.size()
        dtype = x_q.dtype

        q = self.lin_q(x_q).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        k = self.lin_k(x_k).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        v = self.lin_v(x_v).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        if self.use_rope:
            q, k = apply_rotary_pos_emb(q, k, self.cos, self.sin, seq_len, dtype)

        context, _ = self.attn(q, k, v, mask)  # context: (batch, n_heads, seq_len, d_k)

        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_heads * self.d_k)

        out = self.out_proj(context)  # (batch, seq_len, d_model)
        return out


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return (x / rms) * self.weight


class SublayerConnection(nn.Module):
    def __init__(self, dim: int, dropout: float):
        super().__init__()
        self.norm = RMSNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer: nn.Module) -> torch.Tensor:
        return x + self.dropout(sublayer(self.norm(x)))


class SwiGLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return x1 * F.silu(x2)


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff * 2, bias=False)
        self.activation = SwiGLU()
        self.linear2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)  # (batch, seq_len, 2*d_ff)
        x = self.activation(x)  # (batch, seq_len, d_ff)
        x = self.linear2(x)  # (batch, seq_len, d_model)
        return self.dropout(x)


class DecoderBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1, max_seq_len: int = 2048):
        super().__init__()
        self.self_attn = MultiHeadedAttention(n_heads, d_model, dropout, max_seq_len, use_rope=True)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.sublayer_attn = SublayerConnection(d_model, dropout)
        self.sublayer_ffn = SublayerConnection(d_model, dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.sublayer_attn(x, lambda _x: self.self_attn(_x, _x, _x, mask))
        x = self.sublayer_ffn(x, self.ffn)
        return x


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        seq_len = x.size(1)
        return self.pos_embedding[:, :seq_len, :]


class PatchEmbedding(nn.Module):
    def __init__(
            self,
            in_channels: int = 3,
            patch_size: int = 16,
            emb_size: int = 768,
            img_size: int = 224
    ):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)  # (batch, emb_size, n_patches_sqrt, n_patches_sqrt)
        x = x.flatten(2)  # (batch, emb_size, n_patches)
        return x.transpose(1, 2)  # (batch, n_patches, emb_size)


class GELU(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))))


class PositionwiseFeedForwardViT(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = GELU()
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return self.dropout(x)


class EncoderBlockViT(nn.Module):
    def __init__(
            self,
            hidden: int,
            attn_heads: int,
            feed_forward_hidden: int,
            dropout: float,
            max_seq_len: int = 2048
    ):
        super().__init__()
        self.attn = MultiHeadedAttention(attn_heads, hidden, dropout, max_seq_len, use_rope=False)
        self.ffn = PositionwiseFeedForwardViT(hidden, feed_forward_hidden, dropout)
        self.sublayer1 = SublayerConnection(hidden, dropout)
        self.sublayer2 = SublayerConnection(hidden, dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.sublayer1(x, lambda _x: self.attn(_x, _x, _x, mask))
        x = self.sublayer2(x, self.ffn)
        return x


class ViT(nn.Module):
    def __init__(
            self,
            img_size: int = 224,
            patch_size: int = 16,
            in_channels: int = 3,
            emb_size: int = 768,
            n_layers: int = 12,
            attn_heads: int = 12,
            dropout: float = 0.1
    ):
        super().__init__()
        self.embedding = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        num_patches = (img_size // patch_size) ** 2
        max_len = num_patches + 1
        self.pos_embedding = PositionalEmbedding(d_model=emb_size, max_len=max_len)
        self.dropout = nn.Dropout(dropout)

        self.transformer_blocks = nn.ModuleList([
            EncoderBlockViT(emb_size, attn_heads, emb_size * 4, dropout)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(emb_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)
        x = self.embedding(x)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding(x)
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x, mask=None)
        return self.norm(x)  # (batch, 1 + n_patches, emb_size)


class LLaVAConfig:
    def __init__(
            self,
            vocab_size: int = 32000,
            d_model: int = 512,
            n_layers: int = 12,
            n_heads: int = 8,
            d_ff: int = 2048,
            max_text_len: int = 512,
            img_size: int = 224,
            patch_size: int = 16,
            img_emb_size: int = 768,
            img_n_layers: int = 12,
            img_n_heads: int = 12,
            dropout: float = 0.1,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.max_text_len = max_text_len

        self.img_size = img_size
        self.patch_size = patch_size
        self.img_emb_size = img_emb_size
        self.img_n_layers = img_n_layers
        self.img_n_heads = img_n_heads

        self.dropout = dropout


class LLaVAModel(nn.Module):
    def __init__(self, config: LLaVAConfig):
        super().__init__()
        self.config = config

        self.vit = ViT(
            img_size=config.img_size,
            patch_size=config.patch_size,
            in_channels=3,
            emb_size=config.img_emb_size,
            n_layers=config.img_n_layers,
            attn_heads=config.img_n_heads,
            dropout=config.dropout
        )
        self.vit_norm = RMSNorm(config.img_emb_size)
        self.img_proj = nn.Linear(config.img_emb_size, config.d_model, bias=False)

        self.token_embedding = TokenEmbedding(config.vocab_size, config.d_model)
        self.dec_pos_embedding = PositionalEmbedding(d_model=config.d_model, max_len=config.max_text_len + 1)

        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                d_ff=config.d_ff,
                dropout=config.dropout,
                max_seq_len=(config.max_text_len + 1)
            )
            for _ in range(config.n_layers)
        ])
        self.final_norm = RMSNorm(config.d_model)

        self.output_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(
            self,
            images: torch.FloatTensor,
            input_ids: torch.LongTensor,
            attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, T = input_ids.size()

        vit_feats = self.vit(images)
        cls_feats = vit_feats[:, 0, :]  # (batch, img_emb_size)
        cls_feats = self.vit_norm(cls_feats)  # (batch, img_emb_size)

        img_prefix = self.img_proj(cls_feats)  # (batch, d_model)
        img_prefix = img_prefix.unsqueeze(1)  # (batch, 1, d_model)

        txt_emb = self.token_embedding(input_ids)  # (batch, T, d_model)

        x = torch.cat([img_prefix, txt_emb], dim=1)  # (batch, 1 + T, d_model)

        x = x + self.dec_pos_embedding(x)  # (batch, 1 + T, d_model)

        full_len = T + 1
        device = x.device
        causal = torch.tril(torch.ones((full_len, full_len), device=device, dtype=torch.bool))
        causal = causal.unsqueeze(0).unsqueeze(1)  # (1, 1, full_len, full_len)

        if attention_mask is not None:
            padding_mask = attention_mask.view(batch_size, 1, 1, T).to(torch.bool)  # (batch,1,1,T)
            causal_mask = causal.expand(batch_size, -1, -1, -1).clone()  # (batch,1,full_len,full_len)
            causal_mask[:, :, 1:, 1:] = causal_mask[:, :, 1:, 1:] & padding_mask
            mask = causal_mask
        else:
            mask = causal.expand(batch_size, -1, -1, -1)  # (batch,1,full_len,full_len)

        for block in self.decoder_blocks:
            x = block(x, mask)  # x: (batch, 1+T, d_model)

        x = self.final_norm(x)

        txt_feats = x[:, 1:, :]  # (batch, T, d_model)
        logits = self.output_proj(txt_feats) + self.lm_bias  # (batch, T, vocab_size)

        return logits

 

기존 LLaMA 정리글과 ViT 정리글을 봤다면 거의 같은 코드를 사용한 것을 볼 수 있을 것입니다.

고로 LLaVA 모델 클래스를 바로 보자면,

        vit_feats = self.vit(images)
        cls_feats = vit_feats[:, 0, :]  # (batch, img_emb_size)
        cls_feats = self.vit_norm(cls_feats)  # (batch, img_emb_size)

        img_prefix = self.img_proj(cls_feats)  # (batch, d_model)
        img_prefix = img_prefix.unsqueeze(1)  # (batch, 1, d_model)

 

이런식으로 ViT 를 돌려서 인코딩 벡터를 가져오고, 이 중에서도 이미지 전체 의미를 품고 있는 CLS 토큰의 벡터를 추출합니다.

txt_emb = self.token_embedding(input_ids)  # (batch, T, d_model)

 

텍스트 의미 추출은 기존과 전혀 다를 것이 없는 방식을 사용하며,

x = torch.cat([img_prefix, txt_emb], dim=1)

 

이런 방식으로 단순히 두 벡터를 연결하는 것입니다.

 

위치 인코딩은,

x = x + self.dec_pos_embedding(x)

 

이렇게, 두 벡터를 합친 상태에서 동시에 적용합니다. (즉, 이미지는 늘 텍스트보다 앞에 위치한 것으로 처리되는 문제가 있음.)

 

이후로는 LLaMA 의 Decoder Block 에 x 값을 입력하여 결과물을 만드는 것 뿐입니다.

 

(이상입니다.)

- 앞서서 정리한 내용들을 바탕으로 하여 변경된 부분만 살펴보니 생각보다 훨씬 쉬운 기술인 것 같습니다.

 

이번에는 이미지 의미를 추출하는 ViT 와 의미 벡터가 있으면 자연어를 생성하는 GPT 모델의 결합물이라는 의미에서 LLaVA 의 멀티 모달을 이해할 수 있었으며,

나아가서는 Stable Diffusion 과 같은 이미지 생성 모델을 공부한다면 이제까지 배운 멀티 모달 기술이라던지 LLM 의 향상된 Attention 기능과 같은 것들을 응용하여 최신의 이미지 생성 모델의 구조와 원리를 쉽게 이해할 수 있을 것 같습니다.

 

최종적으로는 자연어, 이미지, 음성, 그외 아날로그 신호 등을 모두 처리할 수 있도록 관련 모델들을 하나씩 정리해나갈 생각으로,

결국 어떠한 문제도 인코더와 디코더라는 큰 틀에서 실행 속도와 정확성의 기준하에 발전하는 것이므로, 각각 기술을 차례대로 따라가다보면 어려울 것은 없을 것 같네요.

 

- 다음 게시글은 DeepSeek-VL2 모델을 정리할 예정입니다.

위에서 설명한 LLaVA 의 이미지 처리 방식에 존재하는 어떠한 문제를 해결하여 이미지 인식 성능 향상, 작은 이미지 인식 향상, OCR 및 이미지 속 자연어 의미 추출까지 가능하도록 성능을 끌어올린 모델인데, 어떤 방식으로 개선을 이루어낸 것인지에 대해 알아볼 것입니다.