Expose blip2qformer #37254

Open · wants to merge 2 commits into main
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -54,6 +54,7 @@
("blenderbot-small", "BlenderbotSmallConfig"),
("blip", "BlipConfig"),
("blip-2", "Blip2Config"),
("blip_2_qformer", "Blip2QFormerConfig"),
("bloom", "BloomConfig"),
("bridgetower", "BridgeTowerConfig"),
("bros", "BrosConfig"),
@@ -389,6 +390,7 @@
("blenderbot-small", "BlenderbotSmall"),
("blip", "BLIP"),
("blip-2", "BLIP-2"),
("blip_2_qformer", "BLIP-2 QFormer"),
("bloom", "BLOOM"),
("bort", "BORT"),
("bridgetower", "BridgeTower"),
@@ -776,6 +778,7 @@
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
("blip_2_qformer", "blip_2"),
]
)

1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -54,6 +54,7 @@
("blenderbot-small", "BlenderbotSmallModel"),
("blip", "BlipModel"),
("blip-2", "Blip2Model"),
("blip_2_qformer", "Blip2QFormerModel"),
("bloom", "BloomModel"),
("bridgetower", "BridgeTowerModel"),
("bros", "BrosModel"),
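
For context, a minimal sketch of what these auto-mapping entries enable (assuming a transformers build that includes this change): `AutoConfig.for_model` resolves the `blip_2_qformer` model type to `Blip2QFormerConfig`, and `AutoModel.from_config` then builds a `Blip2QFormerModel` from it.

```python
# Sketch only, assuming this PR's auto mappings are present.
from transformers import AutoConfig, AutoModel

config = AutoConfig.for_model("blip_2_qformer")  # resolves to Blip2QFormerConfig
model = AutoModel.from_config(config)            # builds a Blip2QFormerModel
print(type(config).__name__, type(model).__name__)
```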
2 changes: 2 additions & 0 deletions src/transformers/models/blip_2/configuration_blip_2.py
@@ -144,6 +144,8 @@ class Blip2QFormerConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Index to be used for the padding token.
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
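
As a reference for the documented fields, a small illustrative construction of `Blip2QFormerConfig`; the argument values below simply mirror the documented defaults and are not taken from this PR.

```python
# Illustration only: values mirror the documented defaults.
from transformers import Blip2QFormerConfig

qformer_config = Blip2QFormerConfig(
    layer_norm_eps=1e-12,               # epsilon used by the layer normalization layers
    pad_token_id=0,                     # index used for the padding token
    position_embedding_type="absolute",
)
print(qformer_config.pad_token_id, qformer_config.position_embedding_type)
```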
101 changes: 80 additions & 21 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -456,6 +456,21 @@ def _init_weights(self, module):
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BLIP_2_QFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.

Parameters:
config ([`Blip2QFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BLIP_2_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -621,6 +636,60 @@ def _init_weights(self, module):
"""


BLIP2_QFORMER_INPUTS_DOCSTRING = r"""
Args:
query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Hidden states to be used in the attention computation. For cross-attention, these act as the queries
(the keys and values are taken from `encoder_hidden_states`).

query_length (`int`, *optional*):
Length of the query, usually based on the number of query tokens.
If no value is provided, `query_length` is inferred from `query_embeds`.

attention_mask (`torch.FloatTensor`, *optional*):
Attention mask of size `(batch, sequence_length)` where padding elements
are indicated by 0.

head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.

encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.

encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers`, with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.

use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).

output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.

output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.

return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
class Blip2Encoder(nn.Module):
"""
@@ -1248,11 +1317,13 @@ def forward(
return embeddings


class Blip2QFormerModel(Blip2PreTrainedModel):
"""
Querying Transformer (Q-Former), used in BLIP-2.
@add_start_docstrings(
"""

BLIP-2 Querying Transformer (Q-Former).
""",
BLIP_2_QFORMER_START_DOCSTRING,
)
class Blip2QFormerModel(Blip2PreTrainedModel):
def __init__(self, config: Blip2QFormerConfig):
super().__init__(config)
self.config = config
@@ -1323,6 +1394,10 @@ def get_extended_attention_mask(
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask

@add_start_docstrings_to_model_forward(BLIP2_QFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=Blip2QFormerConfig
)
def forward(
self,
query_embeds: torch.FloatTensor,
@@ -1338,23 +1413,7 @@ def forward(
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, `optional`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
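
To make the newly exposed module concrete, a rough usage sketch of `Blip2QFormerModel` built from a default config and driven with random tensors; the shapes (32 query tokens, 257 vision tokens) and the module-path import are illustrative assumptions, not values taken from this PR.

```python
# Rough sketch, not an official example: run the Q-Former on random inputs.
import torch
from transformers import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel

config = Blip2QFormerConfig()  # hidden_size=768, encoder_hidden_size=1408 by default
model = Blip2QFormerModel(config).eval()

query_embeds = torch.randn(2, 32, config.hidden_size)           # stand-in for learned query tokens
image_embeds = torch.randn(2, 257, config.encoder_hidden_size)  # stand-in for Blip2VisionModel output

with torch.no_grad():
    outputs = model(
        query_embeds=query_embeds,
        encoder_hidden_states=image_embeds,  # attended to via cross-attention
        return_dict=True,
    )
print(outputs.last_hidden_state.shape)  # torch.Size([2, 32, 768])
```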
1 change: 0 additions & 1 deletion utils/check_docstrings.py
@@ -105,7 +105,6 @@
"BlenderbotSmallConfig",
"BlenderbotSmallTokenizerFast",
"BlenderbotTokenizerFast",
"Blip2QFormerConfig",
"Blip2VisionConfig",
"BlipTextConfig",
"BlipVisionConfig",
1 change: 0 additions & 1 deletion utils/check_repo.py
@@ -184,7 +184,6 @@
"ClapAudioModelWithProjection",
"Blip2TextModelWithProjection",
"Blip2VisionModelWithProjection",
"Blip2QFormerModel",
"Blip2VisionModel",
"ErnieMForInformationExtraction",
"FastSpeech2ConformerHifiGan",