Skip to content

Commit 6251835

Browse files
committed
Add xception model
1 parent 6916146 commit 6251835

11 files changed

+450
-0
lines changed

keras_hub/api/layers/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,6 @@
8181
from keras_hub.src.models.whisper.whisper_audio_converter import (
8282
WhisperAudioConverter,
8383
)
84+
from keras_hub.src.models.xception.xception_image_converter import (
85+
XceptionImageConverter,
86+
)

keras_hub/api/models/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,13 @@
376376
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
377377
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
378378
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
379+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
380+
from keras_hub.src.models.xception.xception_image_classifier import (
381+
XceptionImageClassifier,
382+
)
383+
from keras_hub.src.models.xception.xception_image_classifier_preprocessor import (
384+
XceptionImageClassifierPreprocessor,
385+
)
379386
from keras_hub.src.models.xlm_roberta.xlm_roberta_backbone import (
380387
XLMRobertaBackbone,
381388
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
2+
from keras_hub.src.models.xception.xception_presets import backbone_presets
3+
from keras_hub.src.utils.preset_utils import register_presets
4+
5+
register_presets(backbone_presets, XceptionBackbone)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import functools
2+
3+
from keras import layers
4+
5+
from keras_hub.src.api_export import keras_hub_export
6+
from keras_hub.src.models.backbone import Backbone
7+
from keras_hub.src.utils.keras_utils import standardize_data_format
8+
9+
10+
@keras_hub_export("keras_hub.models.XceptionBackbone")
11+
class XceptionBackbone(Backbone):
12+
"""Xception core network with hyperparameters.
13+
14+
This class implements a Xception backbone as described in
15+
[Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357).
16+
17+
Most users will want the pretrained presets available with this model. If
18+
you are creating a custom backbone, this model provides customizability
19+
through the `stackwise_conv_filters` and `stackwise_pooling` arguments. This
20+
backbone assumes the same basic structure as the original Xception mode:
21+
* Residuals and pre-activation everywhere but the first and last block.
22+
* Conv layers for the first block only, separable conv layers elsewhere.
23+
24+
Args:
25+
stackwise_conv_filters: list of list of ints. Each outermost list
26+
entry represents a block, and each innermost list entry a conv
27+
layer. The integer value specifies the number of filters for the
28+
conv layer.
29+
stackwise_pooling: list of bools. A list of booleans per block, where
30+
each entry is true if the block should includes a max pooling layer
31+
and false if it should not.
32+
image_shape: tuple. The input shape without the batch size.
33+
Defaults to `(None, None, 3)`.
34+
data_format: `None` or str. If specified, either `"channels_last"` or
35+
`"channels_first"`. If unspecified, the Keras default will be used.
36+
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
37+
to use for the model's computations and weights.
38+
39+
Examples:
40+
```python
41+
input_data = np.random.uniform(0, 1, size=(2, 224, 224, 3))
42+
43+
# Pretrained Xception backbone.
44+
model = keras_hub.models.Backbone.from_preset("exception_41_imagenet")
45+
model(input_data)
46+
47+
# Randomly initialized Xception backbone with a custom config.
48+
model = keras_hub.models.XceptionBackbone(
49+
stackwise_conv_filters=[[32, 64], [64, 128], [256, 256]],
50+
stackwise_pooling=[True, True, False],
51+
)
52+
model(input_data)
53+
```
54+
"""
55+
56+
def __init__(
57+
self,
58+
stackwise_conv_filters,
59+
stackwise_pooling,
60+
image_shape=(None, None, 3),
61+
data_format=None,
62+
dtype=None,
63+
**kwargs,
64+
):
65+
if len(stackwise_conv_filters) != len(stackwise_pooling):
66+
raise ValueError("All stackwise args should have the same length.")
67+
68+
data_format = standardize_data_format(data_format)
69+
channel_axis = -1 if data_format == "channels_last" else 1
70+
num_blocks = len(stackwise_conv_filters)
71+
72+
# Layer shorcuts with common args.
73+
norm = functools.partial(
74+
layers.BatchNormalization,
75+
axis=channel_axis,
76+
dtype=dtype,
77+
)
78+
act = functools.partial(
79+
layers.Activation,
80+
activation="relu",
81+
dtype=dtype,
82+
)
83+
conv = functools.partial(
84+
layers.Conv2D,
85+
kernel_size=(3, 3),
86+
use_bias=False,
87+
data_format=data_format,
88+
dtype=dtype,
89+
)
90+
sep_conv = functools.partial(
91+
layers.SeparableConv2D,
92+
kernel_size=(3, 3),
93+
padding="same",
94+
use_bias=False,
95+
data_format=data_format,
96+
dtype=dtype,
97+
)
98+
point_conv = functools.partial(
99+
layers.Conv2D,
100+
kernel_size=(1, 1),
101+
strides=(2, 2),
102+
padding="same",
103+
use_bias=False,
104+
data_format=data_format,
105+
dtype=dtype,
106+
)
107+
pool = functools.partial(
108+
layers.MaxPool2D,
109+
pool_size=(3, 3),
110+
strides=(2, 2),
111+
padding="same",
112+
data_format=data_format,
113+
dtype=dtype,
114+
)
115+
116+
# === Functional Model ===
117+
image_input = layers.Input(shape=image_shape)
118+
x = image_input # Intermediate result.
119+
120+
# Iterate through the blocks.
121+
for block_i in range(num_blocks):
122+
first_block, last_block = block_i == 0, block_i == num_blocks - 1
123+
block_filters = stackwise_conv_filters[block_i]
124+
use_pooling = stackwise_pooling[block_i]
125+
126+
# Save the block input as a residual.
127+
residual = x
128+
for conv_i, filters in enumerate(block_filters):
129+
# First block has post activation and strides on first conv.
130+
if first_block:
131+
prefix = f"block{block_i + 1}_conv{conv_i + 1}"
132+
strides = (2, 2) if conv_i == 0 else (1, 1)
133+
x = conv(filters, strides=strides, name=prefix)(x)
134+
x = norm(name=f"{prefix}_bn")(x)
135+
x = act(name=f"{prefix}_act")(x)
136+
# Last block has post activation.
137+
elif last_block:
138+
prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
139+
x = sep_conv(filters, name=prefix)(x)
140+
x = norm(name=f"{prefix}_bn")(x)
141+
x = act(name=f"{prefix}_act")(x)
142+
else:
143+
prefix = f"block{block_i + 1}_sepconv{conv_i + 1}"
144+
# The first conv in second block has no activation.
145+
if block_i != 1 or conv_i != 0:
146+
x = act(name=f"{prefix}_act")(x)
147+
x = sep_conv(filters, name=prefix)(x)
148+
x = norm(name=f"{prefix}_bn")(x)
149+
150+
# Optional block pooling.
151+
if use_pooling:
152+
x = pool(name=f"block{block_i + 1}_pool")(x)
153+
154+
# Sum residual, first and last block do not have a residual.
155+
if not first_block and not last_block:
156+
prefix = f"block{block_i + 1}_residual"
157+
filters = x.shape[channel_axis]
158+
# Match filters with a pointwise conv if needed.
159+
if filters != residual.shape[channel_axis]:
160+
residual = point_conv(filters, name=f"{prefix}_conv")(
161+
residual
162+
)
163+
residual = norm(name=f"{prefix}_bn")(residual)
164+
x = layers.Add(name=f"{prefix}_add", dtype=dtype)([x, residual])
165+
166+
super().__init__(
167+
inputs=image_input,
168+
outputs=x,
169+
dtype=dtype,
170+
**kwargs,
171+
)
172+
173+
# === Config ===
174+
self.stackwise_conv_filters = stackwise_conv_filters
175+
self.stackwise_pooling = stackwise_pooling
176+
self.image_shape = image_shape
177+
self.data_format = data_format
178+
179+
def get_config(self):
180+
config = super().get_config()
181+
config.update(
182+
{
183+
"stackwise_conv_filters": self.stackwise_conv_filters,
184+
"stackwise_pooling": self.stackwise_pooling,
185+
"image_shape": self.image_shape,
186+
}
187+
)
188+
return config
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
from keras import ops
3+
4+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
5+
from keras_hub.src.tests.test_case import TestCase
6+
7+
8+
class XceptionBackboneTest(TestCase):
9+
def setUp(self):
10+
self.init_kwargs = {
11+
"stackwise_conv_filters": [[32, 64], [128, 128], [256, 256]],
12+
"stackwise_pooling": [False, True, False],
13+
"image_shape": (None, None, 3),
14+
}
15+
self.input_size = 64
16+
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))
17+
18+
def test_backbone_basics(self):
19+
self.run_vision_backbone_test(
20+
cls=XceptionBackbone,
21+
init_kwargs=self.init_kwargs,
22+
input_data=self.input_data,
23+
expected_output_shape=(2, 15, 15, 256),
24+
)
25+
26+
@pytest.mark.large
27+
def test_saved_model(self):
28+
self.run_model_saving_test(
29+
cls=XceptionBackbone,
30+
init_kwargs=self.init_kwargs,
31+
input_data=self.input_data,
32+
)
33+
34+
@pytest.mark.extra_large
35+
def test_all_presets(self):
36+
for preset in XceptionBackbone.presets:
37+
self.run_preset_test(
38+
cls=XceptionBackbone,
39+
preset=preset,
40+
input_data=self.input_data,
41+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.models.image_classifier import ImageClassifier
3+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
4+
from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( # noqa: E501
5+
XceptionImageClassifierPreprocessor,
6+
)
7+
8+
9+
@keras_hub_export("keras_hub.models.XceptionImageClassifier")
10+
class XceptionImageClassifier(ImageClassifier):
11+
backbone_cls = XceptionBackbone
12+
preprocessor_cls = XceptionImageClassifierPreprocessor
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.models.image_classifier_preprocessor import (
3+
ImageClassifierPreprocessor,
4+
)
5+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
6+
from keras_hub.src.models.xception.xception_image_converter import (
7+
XceptionImageConverter,
8+
)
9+
10+
11+
@keras_hub_export("keras_hub.models.XceptionImageClassifierPreprocessor")
12+
class XceptionImageClassifierPreprocessor(ImageClassifierPreprocessor):
13+
backbone_cls = XceptionBackbone
14+
image_converter_cls = XceptionImageConverter
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
5+
from keras_hub.src.models.xception.xception_image_classifier import (
6+
XceptionImageClassifier,
7+
)
8+
from keras_hub.src.models.xception.xception_image_classifier_preprocessor import ( # noqa: E501
9+
XceptionImageClassifierPreprocessor,
10+
)
11+
from keras_hub.src.models.xception.xception_image_converter import (
12+
XceptionImageConverter,
13+
)
14+
from keras_hub.src.tests.test_case import TestCase
15+
16+
17+
class XceptionImageClassifierTest(TestCase):
18+
def setUp(self):
19+
self.images = np.ones((2, 299, 299, 3))
20+
self.labels = [0, 1]
21+
self.backbone = XceptionBackbone(
22+
stackwise_conv_filters=[[32, 64], [128, 128], [256, 256]],
23+
stackwise_pooling=[False, True, False],
24+
)
25+
self.image_converter = XceptionImageConverter(
26+
image_size=(299, 299),
27+
scale=1.0 / 127.5,
28+
offset=-1.0,
29+
)
30+
self.prepocessor = XceptionImageClassifierPreprocessor(
31+
image_converter=self.image_converter,
32+
)
33+
self.init_kwargs = {
34+
"backbone": self.backbone,
35+
"preprocessor": self.prepocessor,
36+
"num_classes": 2,
37+
"pooling": "avg",
38+
"activation": "softmax",
39+
}
40+
self.train_data = (self.images, self.labels)
41+
42+
def test_classifier_basics(self):
43+
self.run_task_test(
44+
cls=XceptionImageClassifier,
45+
init_kwargs=self.init_kwargs,
46+
train_data=self.train_data,
47+
expected_output_shape=(2, 2),
48+
)
49+
50+
def test_head_dtype(self):
51+
model = XceptionImageClassifier(
52+
**self.init_kwargs, head_dtype="bfloat16"
53+
)
54+
self.assertEqual(model.output_dense.compute_dtype, "bfloat16")
55+
56+
@pytest.mark.large
57+
def test_smallest_preset(self):
58+
# Test that our forward pass is stable!
59+
image_batch = self.load_test_image()[None, ...].astype("float32")
60+
image_batch = self.image_converter(image_batch)
61+
self.run_preset_test(
62+
cls=XceptionImageClassifier,
63+
preset="xception_41_imagenet",
64+
input_data=image_batch,
65+
expected_output_shape=(1, 1000),
66+
expected_labels=[85],
67+
)
68+
69+
@pytest.mark.large
70+
def test_saved_model(self):
71+
self.run_model_saving_test(
72+
cls=XceptionImageClassifier,
73+
init_kwargs=self.init_kwargs,
74+
input_data=self.images,
75+
)
76+
77+
@pytest.mark.extra_large
78+
def test_all_presets(self):
79+
for preset in XceptionImageClassifier.presets:
80+
self.run_preset_test(
81+
cls=XceptionImageClassifier,
82+
preset=preset,
83+
init_kwargs={"num_classes": 2},
84+
input_data=self.images,
85+
expected_output_shape=(2, 2),
86+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from keras_hub.src.api_export import keras_hub_export
2+
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
3+
from keras_hub.src.models.xception.xception_backbone import XceptionBackbone
4+
5+
6+
@keras_hub_export("keras_hub.layers.XceptionImageConverter")
7+
class XceptionImageConverter(ImageConverter):
8+
backbone_cls = XceptionBackbone
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Xception preset configurations."""
2+
3+
backbone_presets = {
4+
"xception_41_imagenet": {
5+
"metadata": {
6+
"description": (
7+
"41-layer Xception model pre-trained on ImageNet 1k."
8+
),
9+
"params": 20861480,
10+
"path": "xception",
11+
},
12+
"kaggle_handle": "kaggle://keras/xception/keras/xception_41_imagenet/2",
13+
},
14+
}

0 commit comments

Comments
 (0)