| import sys |
|
|
| import torch |
| from einops import rearrange |
| from typing import Optional, Tuple, Union |
|
|
| from torch import nn |
| from transformers import CLIPModel as HFCLIPModel, CLIPVisionConfig |
| from transformers.modeling_outputs import BaseModelOutputWithPooling |
| from transformers.models.clip.modeling_clip import CLIP_VISION_INPUTS_DOCSTRING |
| from transformers.utils import replace_return_docstrings, add_start_docstrings_to_model_forward |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from model.process_clip import get_global_value, set_global_value |
|
|
|
|
| def SET_GLOBAL_VALUE(k, v): |
| set_global_value(k, v) |
|
|
| class CLIPVisionEmbeddings(nn.Module): |
| def __init__(self, config: CLIPVisionConfig): |
| super().__init__() |
| self.config = config |
| self.embed_dim = config.hidden_size |
| self.image_size = config.image_size |
| self.patch_size = config.patch_size |
|
|
| self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
|
|
| self.patch_embedding = nn.Conv2d( |
| in_channels=config.num_channels, |
| out_channels=self.embed_dim, |
| kernel_size=self.patch_size, |
| stride=self.patch_size, |
| bias=False, |
| ) |
|
|
| self.num_patches = (self.image_size // self.patch_size) ** 2 |
| self.num_positions = self.num_patches + 1 |
| self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
| self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) |
|
|
| def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: |
| |
| batch_size = pixel_values.shape[0] |
| patch_embeds = self.patch_embedding(pixel_values) |
| patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
|
|
| class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
| embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
| embeddings = embeddings + self.position_embedding(self.position_ids) |
| return embeddings |
|
|
| class CLIPVisionEmbeddings3D(nn.Module): |
| def __init__(self, config: CLIPVisionConfig): |
| super().__init__() |
| self.config = config |
| self.embed_dim = config.hidden_size |
| self.image_size = config.image_size |
| self.patch_size = config.patch_size |
| self.num_frames = config.num_frames |
| self.tube_size = config.tube_size |
|
|
| self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
|
|
| self.patch_embedding = nn.Conv2d( |
| in_channels=config.num_channels, |
| out_channels=self.embed_dim, |
| kernel_size=self.patch_size, |
| stride=self.patch_size, |
| bias=False, |
| ) |
|
|
| self.num_patches = (self.image_size // self.patch_size) ** 2 |
| self.num_positions = self.num_patches + 1 |
| self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
| self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) |
|
|
| def expand3d(self): |
|
|
| state_dict = self.patch_embedding.state_dict() |
| state_dict_expand = state_dict['weight'].unsqueeze(2) |
| device, dtype = state_dict_expand.device, state_dict_expand.dtype |
| |
|
|
| zero = torch.zeros_like(state_dict_expand).to(device=device, dtype=dtype) |
| state_dict_expand3d = torch.cat([state_dict_expand] + (self.tube_size-1)*[zero], dim=2) |
|
|
| |
|
|
| patch_embedding = nn.Conv3d( |
| in_channels=self.patch_embedding.in_channels, |
| out_channels=self.embed_dim, |
| kernel_size=(self.tube_size, self.patch_size, self.patch_size), |
| stride=(self.tube_size, self.patch_size, self.patch_size), |
| bias=False, |
| ).to(device=device, dtype=dtype) |
| patch_embedding.load_state_dict({'weight': state_dict_expand3d}) |
| self.patch_embedding = patch_embedding |
|
|
|
|
| class_embedding = nn.Parameter(self.class_embedding.data.repeat(self.num_frames // self.tube_size, 1)).to(device=device, dtype=dtype) |
| self.class_embedding = class_embedding |
|
|
|
|
| def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: |
| |
| batch_size = pixel_values.shape[0] // self.num_frames |
| pixel_values = rearrange(pixel_values, '(b t) c h w -> b c t h w', b=batch_size, t=self.num_frames) |
| |
| patch_embeds = self.patch_embedding(pixel_values) |
| |
| |
| patch_embeds = rearrange(patch_embeds, 'b c t h w -> b t (h w) c') |
|
|
| class_embeds = self.class_embedding.unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, 1, 1) |
| |
| |
| embeddings = torch.cat([class_embeds, patch_embeds], dim=2) |
| embeddings = embeddings + self.position_embedding(self.position_ids) |
| embeddings = rearrange(embeddings, 'b t hw_1 c -> (b t) hw_1 c') |
| return embeddings |
|
|
| class CLIPModel(HFCLIPModel): |
| def __init__(self, config, num_frames, add_time_attn, vl_new, tube_size): |
| super(CLIPModel, self).__init__(config) |
| config.vision_config.num_frames = num_frames |
| config.vision_config.tube_size = tube_size |
| if add_time_attn: |
| if vl_new: |
| self.vision_model.embeddings = CLIPVisionEmbeddings3D(config.vision_config) |
| else: |
| self.vision_model.embeddings = CLIPVisionEmbeddings(config.vision_config) |
| self.T = config.vision_config.num_frames // config.vision_config.tube_size |
| self.vision_model.forward = self.vision_model_forward |
|
|
| @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig) |
| def vision_model_forward( |
| self, |
| pixel_values: Optional[torch.FloatTensor] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple, BaseModelOutputWithPooling]: |
| r""" |
| Returns: |
| |
| """ |
| output_attentions = output_attentions if output_attentions is not None else self.vision_model.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.vision_model.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.vision_model.config.use_return_dict |
|
|
| if pixel_values is None: |
| raise ValueError("You have to specify pixel_values") |
|
|
| if len(pixel_values.shape) == 7: |
| b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape |
| |
| B = b_new * pair_new * bs_new |
| pixel_values = pixel_values.reshape(B*T, channel_new, h_new, w_new) |
|
|
| elif len(pixel_values.shape) == 5: |
| B, _, T, _, _ = pixel_values.shape |
| |
| pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w') |
| else: |
| |
| B, _, _, _ = pixel_values.shape |
| T = 1 |
| hidden_states = self.vision_model.embeddings(pixel_values) |
| |
| |
| |
| |
| |
| |
| |
| T = self.T |
| |
| hidden_states = self.vision_model.patch_dropout(hidden_states, B, T) |
| |
| hidden_states = self.vision_model.pre_layrnorm(hidden_states) |
|
|
| encoder_outputs = self.vision_model.encoder( |
| inputs_embeds=hidden_states, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| last_hidden_state = encoder_outputs[0] |
| pooled_output = last_hidden_state[:, 0, :] |
| pooled_output = self.vision_model.post_layernorm(pooled_output) |
|
|
| pooled_output = pooled_output.reshape(B, T, -1).mean(1) |
| if not return_dict: |
| return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
|
|
| return BaseModelOutputWithPooling( |
| last_hidden_state=last_hidden_state, |
| pooler_output=pooled_output, |
| hidden_states=encoder_outputs.hidden_states, |
| attentions=encoder_outputs.attentions, |
| ) |
|
|
| def encode_image(self, image, normalize: bool = False): |
| vision_outputs = self.vision_model( |
| pixel_values=image, |
| return_dict=True, |
| ) |
| image_embeds = vision_outputs[1] |
| image_embeds = self.visual_projection(image_embeds) |
| return image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) if normalize else image_embeds |
|
|
| def encode_text(self, input_ids, attention_mask, normalize: bool = False): |
| text_outputs = self.text_model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| return_dict=True, |
| ) |
| text_embeds = text_outputs[1] |
| text_embeds = self.text_projection(text_embeds) |
| return text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) if normalize else text_embeds |
|
|
|
|
| def forward( |
| self, |
| image=None, |
| input_ids=None, attention_mask=None |
| ): |
| image_features = self.encode_image(image, normalize=True) if image is not None else None |
| text_features = self.encode_text(input_ids, attention_mask, normalize=True) if input_ids is not None else None |
| |
| return { |
| "image_features": image_features, |
| "text_features": text_features, |
| "logit_scale": self.logit_scale.exp() |
| } |
| |
|
|
|
|
|
|