How do I load a GradientBoostedTreesModel to predict on new data? #136

Closed
dempsey-ryan opened this issue Oct 3, 2022 · 10 comments

@dempsey-ryan

Following the documentation, I save the model as follows:

model.save("deployed_models/regtree/model")

Then, when I try to load it and predict on new data, I get

ValueError: Exception encountered when calling layer "gradient_boosted_trees_model" (type GradientBoostedTreesModel).

Could not find matching concrete function to call loaded from the SavedModel. Got:

for the few different things I've tried. Here is the relevant part of the training pipeline:

    all_features = ["bw_khz", "qos_rsrp", "qos_rsrq", "qos_rssnr"]
    train_and_val, test_data = train_test_split(all_data, test_size=1-TRAIN_DEV_SPLIT)
    ## create tensors
    # first make smaller dfs with only features and output
    tf_dataset = train_and_val[all_features + ["PrubPct"]]
    tf_test = test_data[all_features + ["PrubPct"]]

    # training tensor
    tf_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        tf_dataset,
        label="PrubPct",
        task=tf_core.Task.REGRESSION,
        in_place=False
    )
    # test tensor
    tf_test = tfdf.keras.pd_dataframe_to_tf_dataset(
        tf_test,
        label="PrubPct",
        task=tf_core.Task.REGRESSION,
        in_place=False
    )

    
    ### model stuff
    # model features
    tf_features = list(map(
        lambda x: tfdf.keras.FeatureUsage(x, semantic=tf_core.Semantic.NUMERICAL, discretized=False),
        NUMERICAL_FEATURES
    ))

    # instantiate the model
    model = tfdf.keras.GradientBoostedTreesModel(
        features=tf_features,
        task=tf_core.Task.REGRESSION,
        exclude_non_specified_features=True,
        max_depth=int((1-DEV_SPLIT)*train_and_val.shape[0] - 1),
        validation_ratio=DEV_SPLIT,#default
        loss='SQUARED_ERROR',
        growing_strategy='BEST_FIRST_GLOBAL',
        max_num_nodes=10**6, # best, seemingly
        verbose=2
    )

    # do the training!
    model.fit(x=tf_dataset, callbacks=[tfa.callbacks.TQDMProgressBar()])

    ### done training
    print(model.summary())

    ## print metrics
    # train/dev
    model.compile(metrics=['mse', 'mae'])
    loss, mse, mae = model.evaluate(tf_dataset)
    print(f"train\nloss={loss}, rmse={math.sqrt(mse)}, mae={mae}")
    # test
    loss, mse, mae = model.evaluate(tf_test)
    print(f"unseen data\nloss={loss}, rmse={math.sqrt(mse)}, mae={mae}")

    # predict on unseen test data
    pd.options.mode.chained_assignment = None
    test_data['predicted_prbu'] = model.predict(tf_test)
    test_data['bias'] = test_data['predicted_prbu'] - test_data['PrubPct']
    pd.options.mode.chained_assignment = 'warn'

The documentation seems to imply I can use numpy for predictions, but that doesn't work.

import tensorflow_decision_forests as tfdf
import tensorflow as tf

test_model = tf.keras.models.load_model('deployed_models/regtree/model')
test_model.predict(test_data[all_features].to_numpy())
Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * Tensor("inputs:0", shape=(32, 4), dtype=float32)
        * False
      Keyword arguments: {}

     Expected these arguments to match one of the following 4 option(s):

    Option 1:
      Positional arguments (2 total):
        * {'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/bw_khz'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rssnr'), 'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrp'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrq')}
        * False
      Keyword arguments: {}

    Option 2:
      Positional arguments (2 total):
        * {'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rssnr'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/bw_khz'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrq'), 'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrp')}
        * True
      Keyword arguments: {}

    Option 3:
      Positional arguments (2 total):
        * {'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrq'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rssnr'), 'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrp'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='bw_khz')}
        * False
      Keyword arguments: {}

    Option 4:
      Positional arguments (2 total):
        * {'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrp'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrq'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='bw_khz'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rssnr')}
        * True
      Keyword arguments: {}

    Call arguments received:
      • args=('tf.Tensor(shape=(32, 4), dtype=float32)',)
      • kwargs={'training': 'False'}

I tried converting to Tensor, to list of numpy, list of Tensor, etc. and can't find one that works. What am I getting wrong?

@achoum
Collaborator

achoum commented Oct 3, 2022

The error (starting with "Could not find matching concrete function to call loaded from the SavedModel. Got:...") indicates that the examples passed to predict are not structured like the examples given during training.

During training: A dictionary of tensors.
During predict: A single tensor. Tensor("inputs:0", shape=(32, 4), dtype=float32)

You need to make sure to pass the same "things".

You have a couple of options:

1. Call "predict" on a TensorFlow dataset. In your example, this is test_model.predict(tf_test).

   Note: The function pd_dataframe_to_tf_dataset can be used to convert a Pandas dataframe to a TF Dataset.

2. Use "predict_on_batch". This function can consume dictionaries of tensors directly (no need to create a TF Dataset). However, this only works for small datasets.

3. Train your model on a numpy array. This way, you can also call "predict" on a numpy array.
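
To make the structural difference concrete, here is a minimal sketch in plain Python (no TensorFlow needed) that turns a row-major batch, like the (32, 4) array in the error, into the one-column-per-feature dictionary that the expected signatures list. The helper name rows_to_feature_dict and the sample values are hypothetical; the feature names and dtypes are taken from the error message.

```python
# Feature order used when the rows were assembled (from the training snippet).
all_features = ["bw_khz", "qos_rsrp", "qos_rsrq", "qos_rssnr"]

# Python types standing in for the dtypes listed in the error message:
# float stands for tf.float32, int stands for tf.int64.
expected_types = {"bw_khz": float, "qos_rsrp": int, "qos_rsrq": int, "qos_rssnr": int}


def rows_to_feature_dict(rows, feature_names, types):
    """Convert row-major data into one column (a list) per feature name."""
    columns = {}
    for index, name in enumerate(feature_names):
        cast = types[name]
        columns[name] = [cast(row[index]) for row in rows]
    return columns


# Two made-up rows, in the same order as all_features.
rows = [
    [10000.0, -95.0, -11.0, 7.0],
    [20000.0, -101.0, -14.0, 3.0],
]
batch = rows_to_feature_dict(rows, all_features, expected_types)
print(batch)
```

Each column could then be wrapped in a tensor of the matching dtype and the resulting dictionary passed to predict_on_batch, as in option 2. That last step is untested against the saved model from this issue.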

@dempsey-ryan
Author

You have a couple of options:

Call "predict" on a TensorFlow dataset. In your example, this is test_model.predict(tf_test).

Note: The function pd_dataframe_to_tf_dataset can be used to convert from a Pandas dataframe to a TF Dataset.

Thanks for the quick reply and the help. I tried doing this but I think I'm still missing something:

tf_newdata = tfdf.keras.pd_dataframe_to_tf_dataset(test_data[all_features], task=tf_core.Task.REGRESSION, in_place=False)
test_model.predict(tf_newdata)
ValueError: Exception encountered when calling layer "gradient_boosted_trees_model" (type GradientBoostedTreesModel).

    Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * {'bw_khz': <tf.Tensor 'inputs:0' shape=(None,) dtype=float32>, 'qos_rsrp': <tf.Tensor 'inputs_1:0' shape=(None,) dtype=int64>, 'qos_rsrq': <tf.Tensor 'inputs_2:0' shape=(None,) dtype=int64>, 'qos_rssnr': <tf.Tensor 'inputs_3:0' shape=(None,) dtype=float32>}
        * False
      Keyword arguments: {}

     Expected these arguments to match one of the following 4 option(s):

    Option 1:
      Positional arguments (2 total):
        * {'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrp'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrq'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rssnr'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/bw_khz')}
        * False
      Keyword arguments: {}

    Option 2:
      Positional arguments (2 total):
        * {'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrp'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rssnr'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs/bw_khz'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='inputs/qos_rsrq')}
        * True
      Keyword arguments: {}

    Option 3:
      Positional arguments (2 total):
        * {'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrp'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='bw_khz'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rssnr'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrq')}
        * False
      Keyword arguments: {}

    Option 4:
      Positional arguments (2 total):
        * {'qos_rsrp': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrp'), 'qos_rsrq': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rsrq'), 'bw_khz': TensorSpec(shape=(None,), dtype=tf.float32, name='bw_khz'), 'qos_rssnr': TensorSpec(shape=(None,), dtype=tf.int64, name='qos_rssnr')}
        * True
      Keyword arguments: {}

    Call arguments received:
      • args=({'bw_khz': 'tf.Tensor(shape=(None,), dtype=float32)', 'qos_rsrp': 'tf.Tensor(shape=(None,), dtype=int64)', 'qos_rsrq': 'tf.Tensor(shape=(None,), dtype=int64)', 'qos_rssnr': 'tf.Tensor(shape=(None,), dtype=float32)'},)
      • kwargs={'training': 'False'}
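
Comparing the "Got" dictionary with the expected options shows that the structure now matches and only one dtype differs. A small sketch of that comparison in plain Python, with the dtypes transcribed by hand from the message above (the function name diff_signature is hypothetical):

```python
# Dtypes transcribed from the "Got" section of the error message.
got = {
    "bw_khz": "float32",
    "qos_rsrp": "int64",
    "qos_rsrq": "int64",
    "qos_rssnr": "float32",
}

# Dtypes transcribed from "Option 1" (all four options list the same dtypes).
expected = {
    "bw_khz": "float32",
    "qos_rsrp": "int64",
    "qos_rsrq": "int64",
    "qos_rssnr": "int64",
}


def diff_signature(got, expected):
    """Return features that are missing, unexpected, or have a different dtype."""
    missing = sorted(set(expected) - set(got))
    extra = sorted(set(got) - set(expected))
    mismatched = {
        name: (got[name], expected[name])
        for name in sorted(set(got) & set(expected))
        if got[name] != expected[name]
    }
    return {"missing": missing, "extra": extra, "dtype_mismatch": mismatched}


report = diff_signature(got, expected)
print(report)
```

The report flags only qos_rssnr (float32 received, int64 expected). That suggests the new DataFrame column has a different dtype than the one used at training time, so casting that column before calling pd_dataframe_to_tf_dataset may be enough. This is an inference from the message, not a confirmed fix.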

@Realvincentyuan

Realvincentyuan commented Jul 1, 2023

Hi @achoum

I ran into the same error: after loading the model, it shows up even when predicting on the same dataset. In my case, I am building the model with the model builder. Sample code for my workflow is below:

inspector = model.make_inspector()

sample_tree = inspector.extract_tree(tree_idx=0)

# Create some alias
Tree = tfdf.py_tree.tree.Tree
SimpleColumnSpec = tfdf.py_tree.dataspec.SimpleColumnSpec
ColumnType = tfdf.py_tree.dataspec.ColumnType
# Nodes
NonLeafNode = tfdf.py_tree.node.NonLeafNode
LeafNode = tfdf.py_tree.node.LeafNode
# Conditions
NumericalHigherThanCondition = tfdf.py_tree.condition.NumericalHigherThanCondition
CategoricalIsInCondition = tfdf.py_tree.condition.CategoricalIsInCondition
# Leaf values
ProbabilityValue = tfdf.py_tree.value.ProbabilityValue


sample_tree = inspector.extract_tree(tree_idx=0)
print(sample_tree)


# Build a model
model_trial_idx = 1

# Create the model builder

model_trial_idx += 1
model_path = f"/tmp/manual_model/{model_trial_idx}"

builder = tfdf.builder.CARTBuilder(
    path=model_path,
    objective=tfdf.py_tree.objective.ClassificationObjective(
        label="species", classes=["Adelie", "Gentoo", "Chinstrap"])
    )


builder.add_tree(sample_tree)
builder.close()

manual_model = tf.keras.models.load_model(model_path)
manual_model.predict(dataset_tf)

The error is as below:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-47-d9c72aa7e6da> in <cell line: 1>()
----> 1 manual_model.predict(dataset_tf)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/engine/training.py in tf__predict_function(iterator)
     13                 try:
     14                     do_return = True
---> 15                     retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
     16                 except:
     17                     do_return = False

ValueError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2169, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2155, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2143, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.10/dist-packages/keras/engine/training.py", line 2111, in predict_step
        return self(x, training=False)
    File "/usr/local/lib/python3.10/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None

    ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * {'bill_depth_mm': <tf.Tensor 'inputs_2:0' shape=(None,) dtype=float32>,
     'bill_length_mm': <tf.Tensor 'inputs_1:0' shape=(None,) dtype=float32>,
     'body_mass_g': <tf.Tensor 'inputs_4:0' shape=(None,) dtype=float32>,
     'flipper_length_mm': <tf.Tensor 'inputs_3:0' shape=(None,) dtype=float32>,
     'island': <tf.Tensor 'inputs:0' shape=(None,) dtype=string>,
     'sex': <tf.Tensor 'inputs_5:0' shape=(None,) dtype=string>,
     'year': <tf.Tensor 'inputs_6:0' shape=(None,) dtype=int64>}
        * False
      Keyword arguments: {}
    
     Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='island')}
        * True
      Keyword arguments: {}
    
    Option 2:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='island')}
        * False
      Keyword arguments: {}
    
    Option 3:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='inputs_island')}
        * True
      Keyword arguments: {}
    
    Option 4:
      Positional arguments (2 total):
        * {'bill_depth_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_depth_mm'),
     'bill_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_bill_length_mm'),
     'flipper_length_mm': TensorSpec(shape=(None,), dtype=tf.float32, name='inputs_flipper_length_mm'),
     'island': TensorSpec(shape=(None,), dtype=tf.string, name='inputs_island')}
        * False
      Keyword arguments: {}
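
One difference visible in this message: the dataset supplies seven features, while every expected option lists only four. A minimal sketch, in plain Python, of selecting just the expected keys. The helper keep_expected_features and the sample values are hypothetical; the feature names come from the message above.

```python
# Feature names listed in all four expected options of the error message.
expected_features = ["bill_depth_mm", "bill_length_mm", "flipper_length_mm", "island"]


def keep_expected_features(example, expected):
    """Drop every feature that the loaded model's signature does not list."""
    missing = [name for name in expected if name not in example]
    if missing:
        raise KeyError(f"example is missing expected features: {missing}")
    return {name: example[name] for name in expected}


# One made-up batch with the seven keys that appear in the "Got" section.
example = {
    "bill_depth_mm": [18.7],
    "bill_length_mm": [39.1],
    "body_mass_g": [3750.0],
    "flipper_length_mm": [181.0],
    "island": ["Torgersen"],
    "sex": ["male"],
    "year": [2007],
}
filtered = keep_expected_features(example, expected_features)
print(sorted(filtered))
```

If the dataset yields feature dictionaries without labels, the same function could presumably be applied with dataset_tf.map(lambda x: keep_expected_features(x, expected_features)). Whether that alone resolves the error for a builder-made model is not confirmed in this thread.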

I added more context in #184

Appendix

@piotrlaczkowski

Do you happen to have any updates on this matter?
I have the same error on all versions of TFDF (1.2, 1.3, 1.4, 1.5) with the corresponding TF versions (2.11, 2.12, 2.13) and with different models (GradientBoostedTreesModel and RandomForestModel). In my case, I am using datasets for both training and prediction (I have tried additional open-source data as well), and nothing works.

I think it is a major issue if we can't safely save and then reload any of the models.

Please help :)

@dempsey-ryan
Author

@piotrlaczkowski @Realvincentyuan I apologize for not updating sooner. I have not worked on this project for quite some time. However, @achoum's solution (3) did in fact work: I trained and predicted on numpy arrays and had no more issues. See sample code below:

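"""
train_dataset and test_dataset are both Pandas DataFrames,
all_features is a list of strings, OUTPUT_COL is a string,
and os.path.join(TREE_DIR, "model") is where I saved the model after training.
"""
import os

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_decision_forests as tfdf  # import TF-DF before loading a saved TF-DF model

# Train on numpy arrays so that predict can later be called on numpy arrays too.
model.fit(
    x=train_dataset[all_features].to_numpy(),
    y=train_dataset[OUTPUT_COL].to_numpy(),
    callbacks=[tfa.callbacks.TQDMProgressBar()],
)

...
# Reload the saved model and predict on a numpy array with the same column order.
trees_model = tf.keras.models.load_model(os.path.join(TREE_DIR, "model"))
tfdf_features = test_dataset[all_features].to_numpy()
predictions = trees_model.predict(tfdf_features)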

@rstz
Collaborator

rstz commented Jul 26, 2023

@dempsey-ryan Thank you for updating the issue.
@piotrlaczkowski Does this help in your case? If there's still an issue, a Colab to reproduce it (or a full repro code snippet) would be really helpful.

@piotrlaczkowski

Hey @rstz ,
Thank you for your swift response and your interest in the problem. I think we are talking about a serialization/deserialization problem (and not necessarily about the data format itself).
I've put together a simple Colab: https://colab.research.google.com/drive/19sepbkGXwM8lI6fZuovvYRAVSiYCNiKl?usp=sharing
so that you have more details and can play around. I've tested this with all the versions mentioned in my previous comment, and on various systems as well, and I can't figure out what is wrong. I hope this gives you more to work with and that we can quickly solve the problem (if there is one).

Thank you !!

@rstz
Collaborator

rstz commented Jul 28, 2023

Thank you for the repro. This does indeed look like a problem. I can't quite put my finger on it yet, so I'll open a new (more specific) issue for this.

@piotrlaczkowski

piotrlaczkowski commented Jul 28, 2023 via email

@achoum
Collaborator

achoum commented Aug 2, 2023

A loaded model is more restrictive about the data it can consume than the model was before saving. #187 shows how to update the loaded model's input to make inference possible.
