[BUG] best_tft.plot_prediction(new_x, new_raw_predictions, idx=idx, show_future_observed=False) is not working for me #1772

Open
ash1407 opened this issue Feb 21, 2025 · 3 comments
Labels
bug Something isn't working

Comments


ash1407 commented Feb 21, 2025

Describe the bug

To Reproduce

Build TimeSeriesDataSet

predicting = TimeSeriesDataSet(
    pred,
    time_idx="time_idx",
    target="DEPARTURE_DELAY",
    group_ids=['IATA_CODE'],
    min_encoder_length=max_prediction_length, # years, length to encode (can be far longer than the decoder length but does not have to be)
    max_encoder_length=max_encoder_length, # its in years, max_encoder_length
    max_prediction_length=max_prediction_length,
    static_categoricals=static_cats,
    static_reals=static_reals,
    time_varying_known_categoricals=time_varying_cats,
    time_varying_known_reals=[
        "SCHEDULED_DEPARTURE", "DISTANCE", "SCHEDULED_ARRIVAL",
        "SCHEDULED_TIME", "ARRIVAL_TIME", "time_idx"
    ],
    time_varying_unknown_reals=[
        "DEPARTURE_TIME", "ARRIVAL_DELAY", "SCHEDULED_TIME",
        "WIND_GUST_SPEED", "VIS_DIST", "TMP_CELSIUS", "SLP_PRESSURE",
        "DEW_CELSIUS", "CEILING_QUALITY", "CIG_HEIGHT", "WIND_DIRECTION",
        "SECURITY_DELAY", "AIRLINE_DELAY", "LATE_AIRCRAFT_DELAY",
        "WEATHER_DELAY", "AIR_SYSTEM_DELAY", "DIVERTED", "CANCELLED",
        "TAXI_IN", "TAXI_OUT", "WHEELS_OFF",
        "WHEELS_ON", "ELAPSED_TIME", "AIR_TIME", "DEPARTURE_DELAY"
    ],
    target_normalizer=GroupNormalizer(groups=['IATA_CODE'], transformation="softplus"), # use softplus and normalize by group
    categorical_encoders=categorical_encoders, # Apply NaN encoders
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missing_timesteps=True,
)

pred_dataloader = predicting.to_dataloader(train=True, batch_size=128, num_workers=4)

best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);


AttributeError Traceback (most recent call last)
Cell In[231], line 1
----> 1 best_tft.plot_prediction(x, raw_predictions, idx=0, show_future_observed=False);

File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:719, in TemporalFusionTransformer.plot_prediction(self, x, out, idx, plot_attention, add_loss_to_title, show_future_observed, ax, **kwargs)
717 # add attention on secondary axis
718 if plot_attention:
--> 719 interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
720 for f in to_list(fig):
721 ax = f.axes[0]

File c:\Users\qj771f\AppData\Local\Programs\Python\Python310\lib\site-packages\pytorch_forecasting\models\temporal_fusion_transformer\__init__.py:597, in TemporalFusionTransformer.interpret_output(self, out, reduction, attention_prediction_horizon)
595 # roll encoder attention (so start last encoder value is on the right)
596 encoder_attention = out["encoder_attention"]
--> 597 shifts = encoder_attention.size(3) - out["encoder_lengths"]
598 new_index = (
599 torch.arange(encoder_attention.size(3), device=encoder_attention.device)[None, None, None].expand_as(
600 encoder_attention
601 )
602 - shifts[:, None, None, None]
603 ) % encoder_attention.size(3)
604 encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)

AttributeError: 'list' object has no attribute 'size'
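
The traceback shows that `out["encoder_attention"]` is a plain `list` at the point where a tensor is expected. A quick way to check which fields of the raw output are affected is to print the type and shape of each one. The helper below is only a diagnostic sketch: `describe_output` is a hypothetical name, not part of pytorch-forecasting, and it assumes the raw output is either a dict-like object or a namedtuple-style object exposing `_asdict()`.

```python
from collections.abc import Mapping


def describe_output(out):
    """Return {field: summary} with the type and shape (if any) of each field.

    Hypothetical helper for debugging; not a pytorch-forecasting API.
    """
    # Assumption: namedtuple-style outputs can be converted with _asdict().
    if not isinstance(out, Mapping) and hasattr(out, "_asdict"):
        out = out._asdict()
    summary = {}
    for name, value in out.items():
        shape = getattr(value, "shape", None)
        if shape is not None:
            # Tensor-like object: report its shape.
            summary[name] = f"{type(value).__name__} shape={tuple(shape)}"
        elif isinstance(value, (list, tuple)):
            # A list here is what triggers the AttributeError on .size().
            summary[name] = f"{type(value).__name__} len={len(value)}"
        else:
            summary[name] = type(value).__name__
    return summary


class _FakeTensor:
    """Stand-in for a tensor so the demo runs without torch installed."""

    shape = (4, 6)


demo = {"prediction": _FakeTensor(), "encoder_attention": [_FakeTensor(), _FakeTensor()]}
for field, info in describe_output(demo).items():
    print(f"{field}: {info}")
```

In the notebook from the report, the call would be `describe_output(raw_predictions)`; any field reported as `list` rather than a tensor with a shape is a candidate for the failure at `encoder_attention.size(3)`.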


Collaborator

fkiraly commented Feb 22, 2025

@ash1407, this example is not fully reproducible since it does not include the data.

Would you be able to produce some dummy data and add all the imports, so that the bug report is self-contained?

We can then use this for debugging.
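
One possible shape for such dummy data is sketched below. It is only an illustration: `make_dummy_flights` is a hypothetical helper, the values are random rather than taken from the reporter's data, and only a handful of the columns from the `TimeSeriesDataSet` call above are included.

```python
import numpy as np
import pandas as pd


def make_dummy_flights(n_steps=30, airports=("IAH", "DAL", "DTW"), seed=0):
    """Build a small synthetic frame with one series per airport code.

    Hypothetical helper: column names follow the issue, values are random.
    """
    rng = np.random.default_rng(seed)
    frames = []
    for code in airports:
        frames.append(
            pd.DataFrame(
                {
                    # consecutive integer time index per group
                    "time_idx": np.arange(n_steps),
                    "IATA_CODE": code,
                    # normally distributed, so negative values can occur
                    "DEPARTURE_DELAY": rng.normal(5.0, 10.0, n_steps).round(1),
                    "TMP_CELSIUS": rng.normal(10.0, 15.0, n_steps).round(1),
                    "DISTANCE": rng.integers(200, 1500, n_steps).astype(float),
                }
            )
        )
    return pd.concat(frames, ignore_index=True)


df = make_dummy_flights()
print(df.head())
print(df.dtypes)
```

A frame like this could stand in for `pred` in the report, with the column lists in the `TimeSeriesDataSet` call trimmed to the columns that exist.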

Author

ash1407 commented Feb 23, 2025

The data contains some negative values, e.g. for temperature, departure delay, and dew point temperature.
Versions used:
pytorch_forecasting: 0.10.1
pandas: 1.5.0
scikit-learn: 1.0.2
matplotlib: 3.7.1
numpy: 1.26.4
seaborn: 0.12.2
pytorch_lightning: 1.9.5
torch: 1.13.1+cu116
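
A version list like the one above can also be collected programmatically, which makes it easy to paste into a report. The snippet is a minimal sketch using only the standard library; the distribution names in `PACKAGES` are assumed to match how the packages were installed, and `collect_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed distribution names for the packages listed in this comment.
PACKAGES = [
    "pytorch-forecasting",
    "pandas",
    "scikit-learn",
    "matplotlib",
    "numpy",
    "seaborn",
    "pytorch-lightning",
    "torch",
]


def collect_versions(packages):
    """Return {distribution name: installed version or 'not installed'}."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


report = collect_versions(PACKAGES)
for name, ver in report.items():
    print(f"{name}: {ver}")
```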

time_idx DATE YEAR MONTH DAY IATA_CODE AIRPORT CITY STATE COUNTRY LATITUDE LONGITUDE DAY_OF_WEEK AIRLINE FLIGHT_NUMBER TAIL_NUMBER ORIGIN_AIRPORT DESTINATION_AIRPORT SCHEDULED_DEPARTURE DEPARTURE_TIME DEPARTURE_DELAY TAXI_OUT WHEELS_OFF SCHEDULED_TIME ELAPSED_TIME AIR_TIME DISTANCE WHEELS_ON TAXI_IN SCHEDULED_ARRIVAL ARRIVAL_TIME ARRIVAL_DELAY DIVERTED CANCELLED CANCELLATION_REASON AIR_SYSTEM_DELAY SECURITY_DELAY AIRLINE_DELAY LATE_AIRCRAFT_DELAY WEATHER_DELAY SCHEDULED_DEPARTURE_NEW_MIN WIND_DIRECTION WIND_SPEED WIND_CHAR WIND_GUST_SPEED WIND_QUALITY CIG_HEIGHT CEILING_QUALITY CIG_TYPE CLOUD_DETECTION VIS_DIST VIS_QUAL VIS_OBSTRUCT VIS_QUALITY2 TMP_CELSIUS TEMP_QUALITY DEW_CELSIUS DEW_FLAG SLP_PRESSURE SLP_QUALITY
0 1/5/2015 2015 1 5 IAH George Bush Intercontinental Airport Houston TX USA 29.98047 95.33972 1 EV 4298 N12160 IAH DTW 1015 1013 2 14 1027 169 155 133 1075 1340 8 1404 1348 16 0 0 B 3 0 5 1 0 1015 40 5 N 36 5 7620 5 M N 16093 5 N 5 28 5 39 5 10387 5
0 1/9/2015 2015 1 9 DAL Dallas Love Field Dallas TX USA 32.84711 96.85177 5 WN 1614 N413WN DAL DEN 1710 1751 41 6 1757 125 112 94 651 1831 12 1815 1843 28 0 0 B 0 0 28 0 0 1710 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 1/2/2015 2015 1 2 DTW Detroit Metropolitan Airport Detroit MI USA 42.21206 83.34884 5 EV 5086 N852AS DTW MLI 1525 1522 3 10 1532 91 77 64 373 1536 3 1556 1539 17 0 0 B 3 0 5 1 0 1525 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 1/6/2015 2015 1 6 SLC Salt Lake City International Airport Salt Lake City UT USA 40.78839 111.9778 2 OO 4755 N549CA SLC PHX 1651 1736 45 12 1748 100 84 69 507 1857 3 1831 1900 29 0 0 B 0 0 27 2 0 1651 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 ######## 2015 1 22 MSP Minneapolis-Saint Paul International Airport Minneapolis MN USA 44.88055 93.21692 4 DL 2628 N896AT MSP DSM 1310 1308 2 10 1318 74 56 43 232 1401 3 1424 1404 20 0 0 B 3 0 5 1 0 1310 220 5 N 62 5 427 5 M N 16093 5 N 5 -56 5 89 5 99999 9
0 1/9/2015 2015 1 9 EWR Newark Liberty International Airport Newark NJ USA 40.6925 74.16866 5 EV 4300 N15572 EWR RIC 1550 1602 12 16 1618 82 87 59 277 1717 12 1712 1729 17 0 0 B 5 0 3 9 0 1550 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 1/3/2015 2015 1 3 MCO Orlando International Airport Orlando FL USA 28.42889 81.31603 6 B6 2228 N354JB MCO EWR 550 548 2 19 607 149 143 116 937 803 8 819 811 8 0 0 B 3 0 5 1 0 550 90 5 N 15 5 61 5 M N 805 5 N 5 200 5 200 5 99999 9
0 1/5/2015 2015 1 5 LAX Los Angeles International Airport Los Angeles CA USA 33.94254 118.4081 1 AA 1691 N851AA LAX LAS 1755 1806 11 14 1820 75 65 42 236 1902 9 1910 1911 1 0 0 B 3 0 5 1 0 1755 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 ######## 2015 1 29 SLC Salt Lake City International Airport Salt Lake City UT USA 40.78839 111.9778 4 DL 1074 N367NW SLC MSP 1511 1517 6 12 1529 149 150 133 991 1842 5 1840 1847 7 0 0 B 3 0 5 1 0 1511 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 ######## 2015 1 17 IAH George Bush Intercontinental Airport Houston TX USA 29.98047 95.33972 6 EV 4184 N12921 IAH LIT 1301 1257 4 16 1313 79 77 56 374 1409 5 1420 1414 6 0 0 B 3 0 5 1 0 1301 220 5 N 46 5 7620 5 M N 16093 5 N 5 200 5 72 5 10175 5
0 ######## 2015 1 22 CLT Charlotte Douglas International Airport Charlotte NC USA 35.21401 80.94313 4 US 1798 N250AY CLT PHL 2010 2010 0 15 2025 97 89 70 449 2135 4 2147 2139 8 0 0 B 3 0 5 1 0 2010 240 5 N 36 5 7620 5 M N 16093 5 N 5 61 5 17 5 10221 5
0 1/2/2015 2015 1 2 DAL Dallas Love Field Dallas TX USA 32.84711 96.85177 5 WN 3300 N638SW DAL STL 1335 1355 20 14 1409 90 89 72 546 1521 3 1505 1524 19 0 0 B 0 0 14 5 0 1335 30 5 N 26 5 213 5 M N 2414 5 N 5 50 5 39 5 99999 9
0 ######## 2015 1 28 ATL Hartsfield-Jackson Atlanta International Airport Atlanta GA USA 33.64044 84.42694 3 DL 1803 N947DN ATL RDU 1120 1120 0 19 1139 80 80 56 356 1235 5 1240 1240 0 0 0 B 3 0 5 1 0 1120 60 5 N 26 5 22000 5 9 N 16093 5 N 5 89 5 50 5 10250 5
0 ######## 2015 1 29 LAX Los Angeles International Airport Los Angeles CA USA 33.94254 118.4081 4 WN 3944 N354SW LAX SFO 1105 1101 4 15 1116 85 74 53 337 1209 6 1230 1215 15 0 0 B 3 0 5 1 0 1105 80 5 N 21 5 5486 5 M N 16093 5 N 5 194 5 67 5 10180 5

Collaborator

fkiraly commented Feb 23, 2025

@ash1407, thanks - could you put that in a single python code block?

You can make Python code blocks with three backticks, like this:

a = 42
print(a)
