
Adding functional CompositeLayer #21099

Open · wants to merge 18 commits into master

Conversation

@martin-gorner (Contributor) commented on Mar 27, 2025

Introduction

CompositeLayer encapsulates a functional subgraph of layers. It is the equivalent of a Functional or Sequential model but as a lighter-weight component, without the model-specific functionality like .fit() etc.

Apart from offering a useful modeling component, one of the ultimate goals of this functionality is to allow programmatic edits on pre-trained models, like adding LoRA-like weights. The current implementation hard-codes LoRA deep inside the Dense layer. This design does not allow users to change the implementation, which is already significantly behind SOTA.

☆☆☆ Demo colab here ☆☆☆

The colab shows how to add LoRA or SVF, two parameter-efficient fine-tuning techniques, to a pretrained model. For LoRA:

  • LoRA wrapper layer with the math: 18 lines of code
  • Patching the model: 7 lines of code
  • Unpatching and folding LoRA weights back: 10 lines of code

For SVF (Singular Value Fine-tuning), an SVD-based variant of LoRA:

  • SVF wrapper layer with the math: 17 lines of code
  • Patching the model: 7 lines of code
  • Unpatching and folding SVF weights back: 12 lines of code

And all the code for doing this uses user-level APIs!
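To give a flavor of what the colab contains, here is a minimal sketch of a LoRA wrapper layer. The class name and all details are my own illustration, not the colab's exact code: it freezes a pretrained Dense layer and adds a trainable low-rank update, zero-initialized so the wrapped layer starts out mathematically unchanged.

```python
from keras import layers, ops

class LoraWrapper(layers.Layer):  # illustrative sketch, not the colab's code
    def __init__(self, dense, rank=8, **kwargs):
        super().__init__(**kwargs)
        self.dense = dense
        self.dense.trainable = False  # freeze the pretrained weights
        self.rank = rank

    def build(self, input_shape):
        # Low-rank factors: A projects down to `rank`, B projects back up.
        # B is zero-initialized so the update starts as a no-op.
        self.lora_a = self.add_weight(
            shape=(input_shape[-1], self.rank), initializer="he_uniform"
        )
        self.lora_b = self.add_weight(
            shape=(self.rank, self.dense.units), initializer="zeros"
        )

    def call(self, x):
        return self.dense(x) + ops.matmul(
            ops.matmul(x, self.lora_a), self.lora_b
        )
```

Patching then amounts to walking the model's layer graph and swapping each target Dense for a wrapper around it, which is the kind of programmatic edit this PR aims to enable at user level.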

API

A CompositeLayer can be created either from a list of layers or a function that defines a graph of layers. There is no
constructor similar to a functional Model(inputs, outputs) because inputs and outputs are usually not known when creating a layer. They will be created when the layer is built.

```python
# Composite layer using a function
def layer_fn(x):
    x = layers.Dense(64, activation='relu')(x)
    outputs = layers.Dense(32)(x)
    return outputs

composite = layers.CompositeLayer(layer_fn)

# Composite layer from a list of layers
composite = layers.CompositeLayer([
    layers.Dense(64, activation='relu'),
    layers.Dense(32),
])
```
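For illustration, a minimal usage sketch (the shapes are my own example, not from the PR): the layer materializes its inputs and outputs on first call, once the input shape is known.

```python
import numpy as np
from keras import layers

composite = layers.CompositeLayer([
    layers.Dense(64, activation="relu"),
    layers.Dense(32),
])

x = np.random.random((8, 16)).astype("float32")
y = composite(x)  # builds the internal graph on first call
print(y.shape)  # (8, 32)
```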

Implementation notes

  1. Functional and CompositeLayer only depend on Function. There is no circular dependency between models and layers.

  2. The current implementation is an intermediate step to make reviewing (diffs) easier. It isolates 4 functions in functional.py that are used by both CompositeLayer and Functional:

    1. compute_input_spec
    2. run_through_graph_with_training_and_mask
    3. function_from_config
    4. serialize_functional_config

      With this approach, no changes are made to the Functional Model class hierarchy.

      The next step would be to move these 4 functions to CompositeLayer, then base Functional on CompositeLayer instead of Function. This will also allow the unification of Functional and Sequential models, since both will be based on CompositeLayer and have a Function once build() is called. Code explicitly testing for Functional or Sequential can then be removed throughout the code base and replaced with isinstance(obj, CompositeLayer) and obj.built.
  3. plot_model and clone_model were adjusted to work with CompositeLayer.

  4. Tests were added for the main edge cases, namely subclasses of CompositeLayer and Functional with the layer graph instantiated inside or outside of the class, in various nesting scenarios, tested for serialization/deserialization and for processing through clone_model.
    Three bug fixes in Functional and clone_model were needed for the tests to pass:

    1. Cleanup of return values in Functional. A list of one element was sometimes returned. Model.assert_input_compatible let this through but Function.assert_input_compatible, which is stricter, did not. Single tensors are now returned in this case.
    2. Cleanup of "is this a functional-like construct?" tests in Functional which were buggy because, surprisingly, inspect.getfullargspec(Functional.__init__) returns an empty list instead of the expected (inputs, outputs) signature (see this colab).
    3. In clone_model, subclasses of Functional are typically cloned as vanilla Functional models. The same conservative behavior was adopted for CompositeLayer. There was, however, a carve-out for subclasses of Functional with a "functional-like" constructor, again with a buggy test. This is a niche use case of the niche clone_model functionality, so the carve-out was simply removed for simplicity.
  5. Passing a list of inputs to a model expecting a dictionary of inputs seems to be allowed, as long as flattening the dict does not result in reordering. There is an explicit reordering test in functional._standardize_inputs (look for "sort"). Changing this in Functional is not possible at this point, but I would consider disallowing it in CompositeLayer. Tests covering this behavior are functional_test.test_list_input_with_dict_build and composite_layer_test.test_list_input_with_dict_build; a sketch of this behavior follows these notes. (Point for discussion)

  6. In the functional.py serialization and deserialization functions serialize_functional_config and function_from_config, there is an explicit test for the type Functional that triggers a node_index adjustment. It was left untouched and not updated for CompositeLayer, as I have not been able to find a situation where this code is useful. A test that triggers this condition was added in functions_test.py, but it passes whether the conditional clauses are there or not.

  7. Optional inputs are supported although no new API was added for them. They can be declared by setting a manual input spec on the CompositeLayer:

```python
# Declare the first input as optional by setting a manual input spec.
layer = CompositeLayer(...)
layer.input_spec = [
    InputSpec(shape=(None, 2), optional=True),
    InputSpec(shape=(None, 2)),
]
layer([None, value])  # this will now work
```

See composite_layer_test.py:test_optional_inputs for a working example.

Point for discussion: is this user-friendly enough or is a new API required?
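To illustrate implementation note 5 (a hedged sketch; the exact API here is an assumption loosely based on the tests named above): a composite layer whose graph takes a dict of inputs appears to also accept a flat list, as long as flattening the dict does not reorder the inputs.

```python
import numpy as np
from keras import layers

def layer_fn(inputs):
    # the graph is defined over a dict of inputs
    return layers.Add()([inputs["a"], inputs["b"]])

layer = layers.CompositeLayer(layer_fn)

t_a = np.zeros((2, 4), dtype="float32")
t_b = np.ones((2, 4), dtype="float32")

out = layer({"a": t_a, "b": t_b})  # canonical dict call (builds the layer)
out = layer([t_a, t_b])  # flat list: tolerated if flattening preserves order
```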

@codecov-commenter commented on Mar 27, 2025

Codecov Report

Attention: Patch coverage is 80.83832% with 64 lines in your changes missing coverage. Please review.

Project coverage is 82.72%. Comparing base (6d26efb) to head (fbcbad5).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| keras/src/utils/model_visualization.py | 0.00% | 27 Missing ⚠️ |
| keras/src/models/functional.py | 85.61% | 8 Missing and 13 partials ⚠️ |
| keras/src/layers/core/composite_layer.py | 90.40% | 7 Missing and 5 partials ⚠️ |
| keras/src/models/cloning.py | 80.95% | 2 Missing and 2 partials ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master   #21099      +/-   ##
==========================================
+ Coverage   82.69%   82.72%   +0.02%     
==========================================
  Files         564      565       +1     
  Lines       54132    54285     +153     
  Branches     8411     8438      +27     
==========================================
+ Hits        44765    44907     +142     
- Misses       7294     7308      +14     
+ Partials     2073     2070       -3     
```
| Flag | Coverage Δ |
|---|---|
| keras | 82.53% <80.83%> (+0.02%) ⬆️ |
| keras-jax | 64.15% <80.53%> (+0.07%) ⬆️ |
| keras-numpy | 59.23% <77.84%> (+0.13%) ⬆️ |
| keras-openvino | 33.72% <74.55%> (+0.84%) ⬆️ |
| keras-tensorflow | 64.43% <80.83%> (+0.06%) ⬆️ |
| keras-torch | 64.12% <80.53%> (+<0.01%) ⬆️ |

Flags with carried forward coverage won't be shown.

Comment on lines 212 to 220
```python
# A subclassed Functional model is always cloned
# as a vanilla Functional model.
new_model = Functional(cloned_inputs, cloned_outputs,
                       name=model.name)
if model.compiled:
    compiled_config = model.get_compile_config()
    new_model.compile_from_config(compiled_config)
return new_model
```

@martin-gorner (Contributor Author) commented on Mar 27, 2025:

This piece of code was moved here from the end of _clone_functional_model when the function was renamed _clone_function_object and repurposed for all functional types.

The test for functional_like_constructor(model.__class__) was removed. See implementation note 4.iii.

Comment on lines 879 to 909
```diff
     # This test is permissive. Any argument combination that
     # could be a Functional init is allowed. This test will be
     # followed by an actual call of the Functional constructor
     # so the worst case is that args are not what they should
     # be and the constructor fails with an explicit error message.
     return (
-        (len(args) == 2)
+        (len(args) >= 2)
         or (len(args) == 1 and "outputs" in kwargs)
         or ("inputs" in kwargs and "outputs" in kwargs)
     )


 def functional_like_constructor(cls):
     # This test is permissive. Any constructor that could be passed
     # inputs and outputs is accepted. This test triggers Functional
     # deserialization when we know we have a functional config so
     # it's OK to try anything that could work.
     init_args = inspect.signature(cls.__init__).parameters
     funct_init_args = (
         ("inputs" in init_args and "outputs" in init_args)
         or ("args" in init_args or "kwargs" in init_args)
     )
     return funct_init_args


 def strict_functional_like_constructor(cls):
     # This test is conservative. Only explicit "inputs" and "outputs"
     # arguments, with those names, are accepted. This test triggers
     # Functional serialization, and we want to do that in a subclass
     # only when an explicitly functional __init__(inputs, outputs)
     # constructor exists in the subclass.
     init_args = inspect.signature(cls.__init__).parameters
     funct_init_args = "inputs" in init_args and "outputs" in init_args
     return funct_init_args
```

@martin-gorner (Contributor Author):

Cleanup of the "is functional-like" logic. See implementation note 4.ii

Comment on lines +224 to +225
```python
def __init__(self, inputs, outputs, *args, param=1, **kwargs):
    super().__init__(inputs, outputs, *args, **kwargs)
```
@martin-gorner (Contributor Author):

The carve-out for functional serialization of subclasses of Functional that have a functional-like constructor was constrained to constructors with explicit inputs and outputs arguments, which are the only ones we can test for. See implementation note 4.ii

Comment on lines +254 to +258
```python
# No way to detect that this can be serialized functionally
# since the graph could have been created inside the custom
# __init__ with the same __init__ args.
config = model.get_config()
self.assertFalse(has_functional_config_keys(config))
```
@martin-gorner (Contributor Author):

In all other cases of subclassing Functional, functional serialization is NOT triggered, since it is not possible to detect whether the layer graph was created outside of the class (in which case functional serialization would be useful) or inside of __init__, in which case regular serialization, which calls __init__ at the end, is enough.

@martin-gorner (Contributor Author) commented on Apr 1, 2025

Here is a demo Colab showing how a user can patch a CompositeLayer-based LLM to enable LoRA:
https://colab.research.google.com/drive/1USgG4S9j3XAUqpUZlvhAbLWguV9Gjc28?usp=sharing

@martin-gorner (Contributor Author) commented on Apr 3, 2025

I added SVF to the demo Colab.

SVF (Singular Value Fine-tuning) is an SVD-based variant of LoRA; explanations here.

It's 36 lines of code in total: 17 for the SVF layer, 7 to patch it into a pretrained model, and 12 to patch it out, all done with user-level APIs.

(And by the way, the code format failure is not me; it's caused by an OpenVINO installation warning.)

@mattdangerw (Member) commented on Apr 4, 2025

Hey @martin-gorner! Probably someone else should review this in more detail, but I have been doing some thinking on the UX (and only the UX so far). Still mulling; for now, just some questions/thoughts...

Is there a performance issue with using a keras.Model as a layer that never touches the training APIs? Or is it just about simplicity? Basically, do the untouched training/saving APIs ever get in the way?

I know that just using today's symbols won't cover all of what you are adding here, in particular building a functional from an unknown shape. And we should figure that out. But adding a Sequential as a sub-component of a larger model is already a pattern in a lot of Keras code. Are we saying that's bad with this PR, or not really?

Also, is there a reason we are pushing sequential and functional style layer construction onto the same symbol? Sequential is separate in our modeling APIs; why is it fused into a single class in this PR? Seems mildly inconsistent.

```python
# A reusable composite layer
class MyCompositeLayer(CompositeLayer):
    @staticmethod
    def my_layer_fn(inputs):
        x = layers.Dense(5)(inputs)
        return layers.Dense(4)(x)

    def __init__(self, **kwargs):
        super().__init__(MyCompositeLayer.my_layer_fn, **kwargs)
```

This feels off for a couple of reasons. For one, recommending a staticmethod-annotated class function is a lot of cognitive load for something we'd ideally want to be as simple as possible.

And it's unclear how this would work for a reusable functional component with a lot of config. If you were writing a block layer with 5-10 arguments related to feature sizes, activations, residuals, etc., how would you pass that information to the functional building method?

Lastly, I wonder if we'd want more predictability in names for subclasses of MyCompositeLayer. The nice thing about subclassed layers today, with __init__/build/call methods, is that you can subclass, chain to super as needed, and either augment or overwrite one of these fixed methods and not the others. We lack that property here.
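As a purely illustrative aside on the config question above (the closure pattern and all names here are my own assumptions, not part of the PR or the discussion): one conceivable way to pass per-instance config to the graph-building function is to define it inside __init__ so it closes over the constructor arguments.

```python
from keras import layers
from keras.layers import CompositeLayer  # import path assumed from this PR

class MyBlock(CompositeLayer):  # hypothetical block layer, not from the PR
    def __init__(self, units, activation="relu", **kwargs):
        # The graph function closes over the constructor args,
        # so no staticmethod or extra plumbing is needed.
        def layer_fn(inputs):
            x = layers.Dense(units, activation=activation)(inputs)
            return layers.Dense(units)(x)

        super().__init__(layer_fn, **kwargs)
        self.units = units
        self.activation = activation
```

Serialization (get_config) would still need the stored arguments; this sketch only addresses graph construction.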
