[BUG] best_tft.plot_prediction(new_x, new_raw_predictions, idx=idx, show_future_observed=False) is not working for me #1772
@ash1407, this example is not fully reproducible since it does not include the data. Would you be able to produce some dummy data and add all the imports, so the bug report is self-contained? We can then use this for debugging.
There are some negative data points in the data, e.g. for temperature, departure delay, and dew point temperature.
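Negative values matter for the target because the dataset in this report normalizes `DEPARTURE_DELAY` with `GroupNormalizer(groups=['IATA_CODE'], transformation="softplus")`. Assuming, without checking the reporter's installed version, that this transformation applies the inverse of softplus, `log(exp(y) - 1)`, to the target before scaling, any delay `y <= 0` falls outside its domain. The stdlib sketch below (hypothetical helpers `softplus` and `softplus_inv`, not the library's own code) shows the pair of functions and what happens to zero and negative delays:

```python
import math


def softplus(x):
    """softplus(x) = log(1 + e^x), written to avoid overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_inv(y):
    """Inverse of softplus: log(e^y - 1) = y + log(1 - e^-y).

    Only defined for y > 0; this sketch returns NaN to mark the undefined region.
    """
    if y <= 0:
        return float("nan")
    return y + math.log(-math.expm1(-y))


# Delays in minutes; early departures are negative.
for delay in [15.0, 0.0, -7.0]:
    print(delay, softplus_inv(delay))
```

Whether this is related to the plotting error is not established in this thread; it is a separate concern about how a target that can be negative is encoded.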
@ash1407, thanks - could you put that in a single python code block? You can make python code blocks with three backticks, like this:

```python
a = 42
print(a)
```
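A self-contained report could start from synthetic data instead of the real flight and weather data. The sketch below builds long-format rows with the standard library, assuming one row per `IATA_CODE` and `time_idx`; the airport codes, the column subset and the distributions are illustrative placeholders, and the remaining columns from the report would need to be added the same way. It deliberately produces negative values for the delay, temperature and dew point columns mentioned above.

```python
import random

# Illustrative placeholders, not taken from the original data.
AIRPORTS = ["ATL", "ORD", "DFW"]
# Columns that can legitimately be negative (early departures, sub-zero temperatures).
SIGNED_COLS = ["DEPARTURE_DELAY", "TMP_CELSIUS", "DEW_CELSIUS"]
# Columns that should stay non-negative.
NON_NEGATIVE_COLS = ["DISTANCE", "SCHEDULED_TIME", "WIND_GUST_SPEED"]


def make_dummy_rows(n_steps=60, seed=0):
    """Build one row per (airport, time step) with reproducible random values."""
    rng = random.Random(seed)
    rows = []
    for code in AIRPORTS:
        for t in range(n_steps):
            row = {"IATA_CODE": code, "time_idx": t}
            for col in SIGNED_COLS:
                row[col] = rng.gauss(0.0, 10.0)
            for col in NON_NEGATIVE_COLS:
                row[col] = abs(rng.gauss(100.0, 30.0))
            rows.append(row)
    return rows


rows = make_dummy_rows()
print(len(rows), rows[0])
```

With pandas installed, `pandas.DataFrame(rows)` turns these rows into a frame that can stand in for `pred` in the reproduction.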
Describe the bug
To Reproduce
```python
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer

# pred, max_encoder_length, max_prediction_length, static_cats, static_reals,
# time_varying_cats, categorical_encoders and best_tft are defined earlier
# in the notebook (not included in this report).

# Build TimeSeriesDataSet
predicting = TimeSeriesDataSet(
    pred,
    time_idx="time_idx",
    target="DEPARTURE_DELAY",
    group_ids=["IATA_CODE"],
    min_encoder_length=max_prediction_length,  # years, length to encode (can be far longer than the decoder length but does not have to be)
    max_encoder_length=max_encoder_length,  # it's in years, max_encoder_length
    max_prediction_length=max_prediction_length,
    static_categoricals=static_cats,
    static_reals=static_reals,
    time_varying_known_categoricals=time_varying_cats,
    time_varying_known_reals=[
        "SCHEDULED_DEPARTURE", "DISTANCE", "SCHEDULED_ARRIVAL",
        "SCHEDULED_TIME", "ARRIVAL_TIME", "time_idx",
    ],
    time_varying_unknown_reals=[
        "DEPARTURE_TIME", "ARRIVAL_DELAY", "SCHEDULED_TIME",  # note: SCHEDULED_TIME is also listed as a known real
        "WIND_GUST_SPEED", "VIS_DIST", "TMP_CELSIUS", "SLP_PRESSURE",
        "DEW_CELSIUS", "CEILING_QUALITY", "CIG_HEIGHT", "WIND_DIRECTION",
        "SECURITY_DELAY", "AIRLINE_DELAY", "LATE_AIRCRAFT_DELAY",
        "WEATHER_DELAY", "AIR_SYSTEM_DELAY", "DIVERTED", "CANCELLED",
        "TAXI_IN", "TAXI_OUT", "WHEELS_OFF",
        "WHEELS_ON", "ELAPSED_TIME", "AIR_TIME", "DEPARTURE_DELAY",
    ],
    target_normalizer=GroupNormalizer(groups=["IATA_CODE"], transformation="softplus"),  # use softplus and normalize by group
    categorical_encoders=categorical_encoders,  # apply NaN encoders
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

# train=True shuffles and drops the last incomplete batch; train=False is the usual choice for prediction
pred_dataloader = predicting.to_dataloader(train=True, batch_size=128, num_workers=4)

# x and raw_predictions presumably come from an earlier best_tft.predict(..., mode="raw") call (not shown)
best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);
```
```
AttributeError                            Traceback (most recent call last)
Cell In[231], line 1
----> 1 best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);

File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:719, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
    717 # add attention on secondary axis
    718 if plot_attention:
--> 719     interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
    720     for f in to_list(fig):
    721         ax = f.axes[0]

File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:597, in TemporalFusionTransformer.interpret_output(self, out, reduction, attention_prediction_horizon)
    595 # roll encoder attention (so start last encoder value is on the right)
    596 encoder_attention = out["encoder_attention"]
--> 597 shifts = encoder_attention.size(3) - out["encoder_lengths"]
    598 new_index = (
    599     torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as(
    600         encoder_attention
    601     )
    602     - shifts[:, None, None, None]
    603 ) % encoder_attention.size(3)
    604 encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)

AttributeError: 'list' object has no attribute 'size'
```
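The traceback shows that `out["encoder_attention"]` is a Python `list` rather than a tensor when `interpret_output` runs, so `encoder_attention.size(3)` fails before any rolling happens. Because `plot_prediction` only calls `interpret_output` when `plot_attention` is true (frames 718-719 of the traceback), passing `plot_attention=False` should sidestep this code path at the cost of the attention overlay; why the attention arrived as a list is not established by this report. For reference, the failing step is meant to shift each sample's attention so that its last real encoder step lands in the right-most column. The sketch below reproduces that index arithmetic for a single row in plain Python, using a hypothetical helper `roll_to_right` and assuming padded positions sit at the end of the row:

```python
def roll_to_right(weights, encoder_length):
    """Return a copy of `weights` rotated so the last real encoder step is right-most.

    Mirrors new_index = (arange(width) - shift) % width with
    shift = width - encoder_length, followed by a gather along that axis.
    Hypothetical helper for illustration; not part of pytorch-forecasting.
    """
    width = len(weights)
    if not 0 <= encoder_length <= width:
        raise ValueError("encoder_length must be in 0..len(weights)")
    shift = width - encoder_length
    # gather: output position i reads input position (i - shift) mod width
    return [weights[(i - shift) % width] for i in range(width)]


# A sample with 3 real encoder steps padded to width 5 (padding at the end).
row = [0.2, 0.3, 0.5, 0.0, 0.0]
print(roll_to_right(row, 3))  # [0.0, 0.0, 0.2, 0.3, 0.5]
```

In the library the same arithmetic runs on a 4-D tensor with one shift per sample, which is why it needs `.size(3)` and broadcasting that a list cannot provide.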