API Reference

microfuel ¤

PATH_ROOT module-attribute ¤

PATH_ROOT = Path(__file__).parent.parent.parent

PATH_DATA module-attribute ¤

PATH_DATA = PATH_ROOT / 'data'

PATH_DATA_RAW module-attribute ¤

PATH_DATA_RAW = PATH_DATA / 'raw'

PATH_PREPROCESSED module-attribute ¤

PATH_PREPROCESSED = PATH_DATA / 'preprocessed'

PATH_PLOTS_OUTPUT module-attribute ¤

PATH_PLOTS_OUTPUT = PATH_DATA / 'plots'

PATH_CHECKPOINTS module-attribute ¤

PATH_CHECKPOINTS = PATH_DATA / 'checkpoints'

PATH_PREDICTIONS module-attribute ¤

PATH_PREDICTIONS = PATH_DATA / 'predictions'

FlightId module-attribute ¤

FlightId: TypeAlias = str

Unique flight identifier: prc_{}

SegmentId module-attribute ¤

SegmentId: TypeAlias = int

AircraftType module-attribute ¤

AircraftType: TypeAlias = Literal[
    "A20N",
    "A320",
    "A359",
    "B788",
    "B738",
    "A332",
    "A21N",
    "A321",
    "B789",
    "B77W",
    "A333",
    "B772",
    "B744",
    "B737",
    "B739",
    "B38M",
    "A319",
    "A306",
    "A388",
    "B752",
    "B748",
    "B77L",
    "B763",
    "MD11",
    "B39M",
    "A318",
]

AIRCRAFT_TYPES module-attribute ¤

AIRCRAFT_TYPES: tuple[AircraftType, ...] = get_args(
    AircraftType
)

AirportIcao module-attribute ¤

AirportIcao: TypeAlias = str

Partition module-attribute ¤

Partition: TypeAlias = Literal[
    "phase1", "phase1_rank", "phase2_rank"
]

Split module-attribute ¤

Split: TypeAlias = Literal['train', 'validation']

SPLITS module-attribute ¤

SPLITS: tuple[Split, ...] = get_args(Split)

deg2rad module-attribute ¤

deg2rad = isqx.convert(isqx.DEG, isqx.RAD)

ft2m module-attribute ¤

ft2m = isqx.convert(isqx.usc.FT, isqx.M)

knot2mps module-attribute ¤

knot2mps = isqx.convert(isqx.usc.KNOT, isqx.M_PERS)

fpm2mps module-attribute ¤

fpm2mps = isqx.convert(
    isqx.usc.FT * isqx.MIN**-1, isqx.M_PERS
)
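
These converters are plain callables, used throughout the codebase on both scalars and numpy arrays (e.g. the ft2m(50000) default in make_trajectories below). A minimal usage sketch, assuming only the call forms already shown in this module:

import numpy as np
from microfuel import fpm2mps, ft2m, knot2mps

altitude_ceiling_m = ft2m(50000)                 # ft -> m, ~15240.0
speeds_mps = knot2mps(np.array([250.0, 480.0]))  # kt -> m/s, elementwise
climb_mps = fpm2mps(2500.0)                      # ft/min -> m/s, ~12.7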

Coordinate2D ¤

Bases: NamedTuple, Generic[_T]

Source code in src/microfuel/__init__.py
class Coordinate2D(NamedTuple, Generic[_T]):
    lng: Annotated[_T, isqx.LONGITUDE(isqx.DEG)]
    lat: Annotated[_T, isqx.LATITUDE(isqx.DEG)]

lng instance-attribute ¤

lng: Annotated[_T, isqx.LONGITUDE(isqx.DEG)]

lat instance-attribute ¤

lat: Annotated[_T, isqx.LATITUDE(isqx.DEG)]
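
Coordinate2D behaves like any NamedTuple; the Annotated unit metadata has no runtime effect. A short sketch with hypothetical coordinates:

from microfuel import Coordinate2D

origin: Coordinate2D[float] = Coordinate2D(lng=8.5492, lat=47.4647)
lng, lat = origin  # ordinary tuple unpacking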

dataloader ¤

logger module-attribute ¤

logger = logging.getLogger(__name__)

SequenceInfo module-attribute ¤

SequenceInfo = namedtuple(
    "SequenceInfo",
    [
        "flight_indices",
        "segment_indices_relative",
        "target",
        "segment_id",
        "aircraft_type_idx",
        "duration_s",
        "flight_id",
    ],
)

Sequence module-attribute ¤

Sequence = namedtuple(
    "Sequence",
    [
        "features_flight",
        "features_segment",
        "target",
        "segment_id",
        "aircraft_type_idx",
        "duration_s",
        "flight_id",
    ],
)

VarlenBatch module-attribute ¤

VarlenBatch = namedtuple(
    "VarlenBatch",
    [
        "x_flight",
        "cu_seqlens_flight",
        "x_segment",
        "cu_seqlens_segment",
        "y",
        "segment_ids",
        "aircraft_type_idx",
        "durations",
    ],
)

AC_TYPE_ALIASES module-attribute ¤

AC_TYPE_ALIASES = {'B734': 'B737'}

VarlenDataset ¤

Bases: Dataset

Source code in src/microfuel/dataloader.py
class VarlenDataset(Dataset):
    def __init__(self, partition: Partition, split: Split | None):
        if split:
            splits = preprocessed.load_splits(partition)
            segment_ids = splits[split]
            flight_ids = (
                raw.scan_fuel(partition)
                .filter(pl.col("idx").is_in(segment_ids))
                .select("flight_id")
                .unique()
                .collect()["flight_id"]
                .to_list()
            )
        else:
            segment_ids = None
            flight_ids = (
                raw.scan_fuel(partition)
                .select("flight_id")
                .unique()
                .collect()["flight_id"]
                .to_list()
            )

        # always get train stats for submission
        self.stats = preprocessed.load_standardisation_stats("phase1")
        self.ac_type_vocab = {ac_type: i for i, ac_type in enumerate(AIRCRAFT_TYPES)}

        self.all_features, self.sequences = _prepare_tensors(
            partition, flight_ids, segment_ids, self.stats, self.ac_type_vocab
        )

        counts = Counter(s.aircraft_type_idx for s in self.sequences)
        self.class_counts = torch.tensor([counts[i] for i in range(len(self.ac_type_vocab))])

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Sequence:
        seq_info = self.sequences[idx]
        flight_start_abs, flight_end_abs = seq_info.flight_indices
        segment_start_rel, segment_end_rel = seq_info.segment_indices_relative

        features_flight = self.all_features[flight_start_abs:flight_end_abs]
        features_segment = features_flight[segment_start_rel:segment_end_rel]

        return Sequence(
            features_flight=features_flight,
            features_segment=features_segment,
            target=seq_info.target,
            segment_id=seq_info.segment_id,
            aircraft_type_idx=seq_info.aircraft_type_idx,
            duration_s=seq_info.duration_s,
            flight_id=seq_info.flight_id,
        )
stats instance-attribute ¤
stats = preprocessed.load_standardisation_stats("phase1")
ac_type_vocab instance-attribute ¤
ac_type_vocab = {
    ac_type: i
    for i, ac_type in (enumerate(AIRCRAFT_TYPES))
}
class_counts instance-attribute ¤
class_counts = torch.tensor(
    [
        (counts[i])
        for i in (range(len(self.ac_type_vocab)))
    ]
)
__init__ ¤
__init__(partition: Partition, split: Split | None)
Source code in src/microfuel/dataloader.py
def __init__(self, partition: Partition, split: Split | None):
    if split:
        splits = preprocessed.load_splits(partition)
        segment_ids = splits[split]
        flight_ids = (
            raw.scan_fuel(partition)
            .filter(pl.col("idx").is_in(segment_ids))
            .select("flight_id")
            .unique()
            .collect()["flight_id"]
            .to_list()
        )
    else:
        segment_ids = None
        flight_ids = (
            raw.scan_fuel(partition)
            .select("flight_id")
            .unique()
            .collect()["flight_id"]
            .to_list()
        )

    # always get train stats for submission
    self.stats = preprocessed.load_standardisation_stats("phase1")
    self.ac_type_vocab = {ac_type: i for i, ac_type in enumerate(AIRCRAFT_TYPES)}

    self.all_features, self.sequences = _prepare_tensors(
        partition, flight_ids, segment_ids, self.stats, self.ac_type_vocab
    )

    counts = Counter(s.aircraft_type_idx for s in self.sequences)
    self.class_counts = torch.tensor([counts[i] for i in range(len(self.ac_type_vocab))])
__len__ ¤
__len__() -> int
Source code in src/microfuel/dataloader.py
def __len__(self) -> int:
    return len(self.sequences)
__getitem__ ¤
__getitem__(idx: int) -> Sequence
Source code in src/microfuel/dataloader.py
def __getitem__(self, idx: int) -> Sequence:
    seq_info = self.sequences[idx]
    flight_start_abs, flight_end_abs = seq_info.flight_indices
    segment_start_rel, segment_end_rel = seq_info.segment_indices_relative

    features_flight = self.all_features[flight_start_abs:flight_end_abs]
    features_segment = features_flight[segment_start_rel:segment_end_rel]

    return Sequence(
        features_flight=features_flight,
        features_segment=features_segment,
        target=seq_info.target,
        segment_id=seq_info.segment_id,
        aircraft_type_idx=seq_info.aircraft_type_idx,
        duration_s=seq_info.duration_s,
        flight_id=seq_info.flight_id,
    )

collate_fn ¤

collate_fn(batch_sequences: list[Sequence]) -> VarlenBatch
Source code in src/microfuel/dataloader.py
def collate_fn(batch_sequences: list[Sequence]) -> VarlenBatch:
    lengths_flight = [len(seq.features_flight) for seq in batch_sequences]
    lengths_segment = [len(seq.features_segment) for seq in batch_sequences]

    x_flight = torch.cat([seq.features_flight for seq in batch_sequences], dim=0)
    x_segment = torch.cat([seq.features_segment for seq in batch_sequences], dim=0)
    y = torch.tensor([seq.target for seq in batch_sequences], dtype=torch.float32).unsqueeze(1)

    cu_seqlens_flight = torch.from_numpy(np.cumsum([0, *lengths_flight], dtype=np.int32))
    cu_seqlens_segment = torch.from_numpy(np.cumsum([0, *lengths_segment], dtype=np.int32))

    segment_ids = torch.tensor([seq.segment_id for seq in batch_sequences], dtype=torch.int32)
    aircraft_type_idx = torch.tensor(
        [seq.aircraft_type_idx for seq in batch_sequences], dtype=torch.long
    )
    durations = torch.tensor([seq.duration_s for seq in batch_sequences], dtype=torch.float32)

    return VarlenBatch(
        x_flight=x_flight,
        cu_seqlens_flight=cu_seqlens_flight,
        x_segment=x_segment,
        cu_seqlens_segment=cu_seqlens_segment,
        y=y,
        segment_ids=segment_ids,
        aircraft_type_idx=aircraft_type_idx,
        durations=durations,
    )
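
A typical way to wire the dataset and collate function together is a standard torch DataLoader; a sketch, assuming the usual torch.utils.data API:

from torch.utils.data import DataLoader

from microfuel.dataloader import VarlenDataset, collate_fn

dataset = VarlenDataset("phase1", split="train")
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

for batch in loader:
    # batch is a VarlenBatch: sequences are packed flat along dim 0, and the
    # cu_seqlens_* tensors hold cumulative boundaries, so sequence i is
    # batch.x_flight[batch.cu_seqlens_flight[i] : batch.cu_seqlens_flight[i + 1]]
    ...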

datasets ¤

preprocessed ¤

logger module-attribute ¤
logger = logging.getLogger(__name__)
CoreFeature module-attribute ¤
CoreFeature = Literal[
    "altitude", "groundspeed", "vertical_rate"
]
WEATHER_FEATURES module-attribute ¤
WEATHER_FEATURES = []
StateFeature module-attribute ¤
StateFeature = Literal[CoreFeature]
STATE_FEATURES module-attribute ¤
STATE_FEATURES = get_args(StateFeature)
FlightFeature module-attribute ¤
FlightFeature = Literal[
    "flight_progress", "flight_duration"
]
FLIGHT_FEATURES module-attribute ¤
FLIGHT_FEATURES = get_args(FlightFeature)
MODEL_INPUT_FEATURES module-attribute ¤
MODEL_INPUT_FEATURES: list[str] = [
    *FLIGHT_FEATURES,
    *STATE_FEATURES,
]
SmoothResult module-attribute ¤
SmoothResult = namedtuple(
    "SmoothResult", ["val", "val_d", "var_val", "var_val_d"]
)
Stats module-attribute ¤
Stat ¤

Bases: TypedDict

Source code in src/microfuel/datasets/preprocessed.py
class Stat(TypedDict):
    mean: float
    std: float
mean instance-attribute ¤
mean: float
std instance-attribute ¤
std: float
IteratorData ¤

Bases: NamedTuple

Source code in src/microfuel/datasets/preprocessed.py
class IteratorData(NamedTuple):
    segments_df: pl.DataFrame
    traj_lf: pl.LazyFrame
segments_df instance-attribute ¤
segments_df: pl.DataFrame
traj_lf instance-attribute ¤
traj_lf: pl.LazyFrame
TrajectoryInfo ¤

Bases: TypedDict

Source code in src/microfuel/datasets/preprocessed.py
class TrajectoryInfo(TypedDict):
    idx: int
    flight_id: str
    start: datetime
    end: datetime
    fuel_kg: float
    takeoff: datetime
    landed: datetime
    aircraft_type: str
idx instance-attribute ¤
idx: int
flight_id instance-attribute ¤
flight_id: str
start instance-attribute ¤
start: datetime
end instance-attribute ¤
end: datetime
fuel_kg instance-attribute ¤
fuel_kg: float
takeoff instance-attribute ¤
takeoff: datetime
landed instance-attribute ¤
landed: datetime
aircraft_type instance-attribute ¤
aircraft_type: str
Trajectory ¤

Bases: NamedTuple

Source code in src/microfuel/datasets/preprocessed.py
class Trajectory(NamedTuple):
    features_df: pl.DataFrame
    info: TrajectoryInfo
features_df instance-attribute ¤
features_df: pl.DataFrame
info instance-attribute ¤
info: TrajectoryInfo
TrajectoryIterator ¤

Yields the entire flight trajectory for each segment as polars DataFrames.

Source code in src/microfuel/datasets/preprocessed.py
class TrajectoryIterator:
    """Yields the entire flight trajectory for each segment as polars DataFrames."""

    def __init__(
        self,
        partition: Partition,
        *,
        segment_ids: Collection[SegmentId] | None = None,
        shuffle_seed: int | None = None,
        stats: Stats | None = None,
        start_to_end_only: bool = False,
        path_base: Path = PATH_PREPROCESSED,
    ):
        """
        :param start_to_end_only: if False, yields the entire materialised flight trajectory.
            Note that collecting this iterator will use a lot of memory due to duplicates!
            Prefer using the torch iterator instead.
        """
        it_data = prepare_iterator_data(partition, segment_ids, stats, path_base)
        self.traj_lf = it_data.traj_lf
        self.segment_infos: list[TrajectoryInfo] = it_data.segments_df.to_dicts()  # type: ignore
        self.start_to_end_only = start_to_end_only
        self.stats = stats

        self.segments_by_flight: dict[FlightId, list[TrajectoryInfo]] = {}
        for info in self.segment_infos:
            self.segments_by_flight.setdefault(info["flight_id"], []).append(info)

        self.flight_ids_to_iterate = list(self.segments_by_flight.keys())
        if shuffle_seed is not None:
            rng = np.random.default_rng(shuffle_seed)
            rng.shuffle(self.flight_ids_to_iterate)

    def __len__(self) -> int:
        return len(self.segment_infos)

    def __iter__(self) -> Iterator[Trajectory]:
        for flight_id in self.flight_ids_to_iterate:
            flight_traj_df = self.traj_lf.filter(pl.col("flight_id") == flight_id).collect()

            for segment_info in self.segments_by_flight[flight_id]:
                if self.start_to_end_only:
                    start_ts = segment_info["start"]
                    end_ts = segment_info["end"]

                    start_idx, end_idx = find_segment_indices(
                        flight_traj_df["timestamp"].to_numpy(),
                        np.datetime64(start_ts.isoformat()),
                        np.datetime64(end_ts.isoformat()),
                        xp=np,
                    )
                    segment_traj_df = flight_traj_df[start_idx:end_idx]
                    if segment_traj_df.height < 2:
                        logger.error(
                            f"skipping {flight_id}: found < 2 datapoints for segment "
                            f"({start_ts} - {end_ts}): {start_idx}..={end_idx}"
                        )
                        continue
                else:
                    segment_traj_df = flight_traj_df

                yield Trajectory(
                    features_df=segment_traj_df,
                    info=segment_info,
                )
traj_lf instance-attribute ¤
traj_lf = it_data.traj_lf
segment_infos instance-attribute ¤
segment_infos: list[TrajectoryInfo] = (
    it_data.segments_df.to_dicts()
)
start_to_end_only instance-attribute ¤
start_to_end_only = start_to_end_only
stats instance-attribute ¤
stats = stats
segments_by_flight instance-attribute ¤
segments_by_flight: dict[
    FlightId, list[TrajectoryInfo]
] = {}
flight_ids_to_iterate instance-attribute ¤
flight_ids_to_iterate = list(self.segments_by_flight.keys())
__init__ ¤
__init__(
    partition: Partition,
    *,
    segment_ids: Collection[SegmentId] | None = None,
    shuffle_seed: int | None = None,
    stats: Stats | None = None,
    start_to_end_only: bool = False,
    path_base: Path = PATH_PREPROCESSED,
)

Parameters:

- start_to_end_only (bool, default False): if False, yields the entire materialised flight trajectory. Note that collecting this iterator will use a lot of memory due to duplicates! Prefer using the torch iterator instead.
Source code in src/microfuel/datasets/preprocessed.py
def __init__(
    self,
    partition: Partition,
    *,
    segment_ids: Collection[SegmentId] | None = None,
    shuffle_seed: int | None = None,
    stats: Stats | None = None,
    start_to_end_only: bool = False,
    path_base: Path = PATH_PREPROCESSED,
):
    """
    :param start_to_end_only: if False, yields the entire materialised flight trajectory.
        Note that collecting this iterator will use a lot of memory due to duplicates!
        Prefer using the torch iterator instead.
    """
    it_data = prepare_iterator_data(partition, segment_ids, stats, path_base)
    self.traj_lf = it_data.traj_lf
    self.segment_infos: list[TrajectoryInfo] = it_data.segments_df.to_dicts()  # type: ignore
    self.start_to_end_only = start_to_end_only
    self.stats = stats

    self.segments_by_flight: dict[FlightId, list[TrajectoryInfo]] = {}
    for info in self.segment_infos:
        self.segments_by_flight.setdefault(info["flight_id"], []).append(info)

    self.flight_ids_to_iterate = list(self.segments_by_flight.keys())
    if shuffle_seed is not None:
        rng = np.random.default_rng(shuffle_seed)
        rng.shuffle(self.flight_ids_to_iterate)
__len__ ¤
__len__() -> int
Source code in src/microfuel/datasets/preprocessed.py
def __len__(self) -> int:
    return len(self.segment_infos)
__iter__ ¤
__iter__() -> Iterator[Trajectory]
Source code in src/microfuel/datasets/preprocessed.py
def __iter__(self) -> Iterator[Trajectory]:
    for flight_id in self.flight_ids_to_iterate:
        flight_traj_df = self.traj_lf.filter(pl.col("flight_id") == flight_id).collect()

        for segment_info in self.segments_by_flight[flight_id]:
            if self.start_to_end_only:
                start_ts = segment_info["start"]
                end_ts = segment_info["end"]

                start_idx, end_idx = find_segment_indices(
                    flight_traj_df["timestamp"].to_numpy(),
                    np.datetime64(start_ts.isoformat()),
                    np.datetime64(end_ts.isoformat()),
                    xp=np,
                )
                segment_traj_df = flight_traj_df[start_idx:end_idx]
                if segment_traj_df.height < 2:
                    logger.error(
                        f"skipping {flight_id}: found < 2 datapoints for segment "
                        f"({start_ts} - {end_ts}): {start_idx}..={end_idx}"
                    )
                    continue
            else:
                segment_traj_df = flight_traj_df

            yield Trajectory(
                features_df=segment_traj_df,
                info=segment_info,
            )
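
A minimal iteration sketch over the documented interface:

from microfuel.datasets.preprocessed import TrajectoryIterator

it = TrajectoryIterator("phase1", start_to_end_only=True, shuffle_seed=13)
for traj in it:
    features_df = traj.features_df        # polars DataFrame of state vectors
    target_kg = traj.info["fuel_kg"]      # segment-level fuel burn
    ac_type = traj.info["aircraft_type"]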
make_splits ¤
make_splits(
    partition: Partition,
    train_split: float = 0.8,
    seed: int = 13,
    *,
    path_base: Path = PATH_PREPROCESSED,
    max_bins: int = 30,
    min_samples_for_binning: int = 2,
)
Source code in src/microfuel/datasets/preprocessed.py
def make_splits(
    partition: Partition,
    train_split: float = 0.8,
    seed: int = 13,
    *,
    path_base: Path = PATH_PREPROCESSED,
    max_bins: int = 30,
    min_samples_for_binning: int = 2,
):  # TODO: allow k fold stratified splits
    path_base.mkdir(exist_ok=True, parents=True)
    flight_list_lf = raw.scan_flight_list(partition)
    fuel_lf = raw.scan_fuel(partition)

    segments_df = (
        fuel_lf.with_columns(
            (pl.col("end") - pl.col("start")).dt.total_seconds().alias("duration_s")
        )
        .join(flight_list_lf.select("flight_id", "aircraft_type"), on="flight_id")
        .select("idx", "aircraft_type", "duration_s")
        .sort("idx")
        .collect()
    )
    logger.info(f"found {len(segments_df)} segments with fuel data in `{partition}`")

    bin_boundaries_data = []

    def add_duration_bin(group_df: pl.DataFrame) -> pl.DataFrame:
        ac_type = group_df["aircraft_type"][0]
        duration_series = group_df["duration_s"]
        n_samples = duration_series.len()
        assert n_samples > 0

        duration_quantiles = (
            min(max_bins, int(1 + np.log2(n_samples)))
            if n_samples >= min_samples_for_binning
            else 1
        )

        quantile_points = np.linspace(0, 1, duration_quantiles + 1)
        breaks_set: set[float] = set()
        for q in quantile_points:
            b = duration_series.quantile(q, interpolation="linear")
            assert b is not None
            breaks_set.add(b)
        breaks = sorted(breaks_set)

        if len(breaks) < 2:
            min_val = float(duration_series.min() or 0)  # type: ignore
            max_val = float(duration_series.max() or 1)  # type: ignore
            breaks = [min_val, max_val] if min_val != max_val else [min_val, min_val + 1]

        # last break should be inclusive of the max value
        max_dur = float(duration_series.max())  # type: ignore
        if max_dur is not None and breaks[-1] < max_dur:
            breaks[-1] = max_dur

        labels = [f"d_q{i}" for i in range(len(breaks) - 1)]
        for i, label in enumerate(labels):
            bin_boundaries_data.append(
                {
                    "aircraft_type": ac_type,
                    "duration_bin": label,
                    "lower_bound": breaks[i],
                    "upper_bound": breaks[i + 1],
                }
            )

        bin_expr = pl.when(pl.col("duration_s") <= breaks[1]).then(pl.lit(labels[0]))
        for i in range(2, len(breaks) - 1):
            bin_expr = bin_expr.when(pl.col("duration_s") <= breaks[i]).then(pl.lit(labels[i - 1]))
        bin_expr = bin_expr.otherwise(pl.lit(labels[-1]))

        return group_df.with_columns(bin_expr.alias("duration_bin"))

    segments_binned_df = segments_df.group_by("aircraft_type", maintain_order=True).map_groups(
        add_duration_bin
    )
    bin_boundaries_df = pl.DataFrame(bin_boundaries_data)

    stratify_cols = ["aircraft_type", "duration_bin"]
    n_train_samples_expr = pl.max_horizontal(1, (pl.count() * train_split).floor())
    train_df = segments_binned_df.filter(
        pl.int_range(0, pl.count()).shuffle(seed=seed).over(stratify_cols)
        < n_train_samples_expr.over(stratify_cols)
    )

    train_segment_ids_set = set(train_df["idx"].to_list())
    all_segment_ids_set = set(segments_binned_df["idx"].to_list())
    validation_segment_ids_set = all_segment_ids_set - train_segment_ids_set

    segment_ids_train = sorted(list(train_segment_ids_set))
    segment_ids_validation = sorted(list(validation_segment_ids_set))

    logger.info(
        f"stratified split by {stratify_cols}: {len(segment_ids_train)} train, {len(segment_ids_validation)} validation"
    )

    train_counts_df = train_df.group_by(stratify_cols).len().rename({"len": "train_count"})
    validation_df = segments_binned_df.filter(pl.col("idx").is_in(validation_segment_ids_set))
    validation_counts_df = (
        validation_df.group_by(stratify_cols).len().rename({"len": "validation_count"})
    )

    all_groups_df = segments_binned_df.select(stratify_cols).unique().sort(stratify_cols)
    combined_counts_df = (
        all_groups_df.join(train_counts_df, on=stratify_cols, how="left")
        .join(validation_counts_df, on=stratify_cols, how="left")
        .fill_null(0)
    )

    logging_df = combined_counts_df.join(bin_boundaries_df, on=stratify_cols, how="left")

    logger.info("split counts by stratification groups:")
    _ac_types: set[str] = set()
    for row in logging_df.sort(["aircraft_type", "duration_bin"]).iter_rows(named=True):
        ac_type = t if (t := row["aircraft_type"]) not in _ac_types else ""
        _ac_types.add(row["aircraft_type"])
        lower = row["lower_bound"]
        upper = row["upper_bound"]
        train_count = row["train_count"]
        validation_count = row["validation_count"]
        total = train_count + validation_count
        train_pct = train_count / total if total > 0 else 0
        duration_str = f"({lower or 0:.0f}s-{upper or 0:.0f}s]"
        logger.info(
            f"  {ac_type:<5}{duration_str:<14}: {train_count:>5}/{validation_count:>5} ({train_pct:5.1%})"
        )

    splits: dict[Split, list[SegmentId]] = {
        "train": segment_ids_train,
        "validation": segment_ids_validation,
    }
    output_path = path_base / f"splits_{partition}.json"
    with open(output_path, "w") as f:
        json.dump(splits, f)
    logger.info(f"wrote splits to {output_path}")
find_segment_indices ¤
find_segment_indices(
    timestamps, start_time, end_time, *, xp=np
)
Source code in src/microfuel/datasets/preprocessed.py
def find_segment_indices(timestamps, start_time, end_time, *, xp=np):
    start_idx = xp.searchsorted(timestamps, start_time, side="left")
    end_idx = xp.searchsorted(timestamps, end_time, side="right")
    return start_idx, end_idx
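
Because of the side="left"/side="right" choice, the returned slice includes rows at both boundary timestamps; a small example with hypothetical values:

import numpy as np

ts = np.array(
    ["2024-01-01T00:00", "2024-01-01T00:01", "2024-01-01T00:02"], dtype="datetime64[s]"
)
start, end = find_segment_indices(
    ts, np.datetime64("2024-01-01T00:01"), np.datetime64("2024-01-01T00:02")
)
# start == 1, end == 3: ts[start:end] covers both segment endpoints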
smooth_time_series ¤
smooth_time_series(
    values,
    dts_s,
    process_noise_variances: tuple[float, float],
    observation_noise_variance: float,
    gap_threshold: float = 30.0,
) -> SmoothResult

Applies a Kalman filter and RTS smoother to a 1D time series, handling large gaps.

Assumes the time series follows a Constant Velocity (CV) system: \(x_k = F x_{k-1} + w_k\).

Parameters:

- process_noise_variances (tuple[float, float], required): (pos, vel) variances for the model's state transition noise (Q).
- observation_noise_variance (float, required): variance for the measurement noise (R).
- gap_threshold (float, default 30.0): time gap (in seconds) above which to split the time series into chunks.
Source code in src/microfuel/datasets/preprocessed.py
def smooth_time_series(
    values,
    dts_s,
    process_noise_variances: tuple[float, float],
    observation_noise_variance: float,
    gap_threshold: float = 30.0,
) -> SmoothResult:
    """Applies a Kalman filter and RTS smoother to a 1D time series, handling large gaps.

    Assumes the time series follows a Constant Velocity (CV) system:
    $x_k = F x_{k-1} + w_k$.

    :param process_noise_variances: (pos, vel) variances for the model's state transition noise (Q).
    :param observation_noise_variance: variance for the measurement noise (R).
    :param gap_threshold: time gap (in seconds) above which to split the time series into chunks.
    """
    gap_indices = np.where(dts_s > gap_threshold)[0]
    chunk_boundaries = np.concatenate(([0], gap_indices + 1, [len(values)]))

    all_smoothed_values = np.full_like(values, np.nan)
    all_smoothed_derivatives = np.full_like(values, np.nan)
    all_smoothed_value_variances = np.full_like(values, np.nan)
    all_smoothed_derivative_variances = np.full_like(values, np.nan)

    transition_matrix_template = np.array([[1.0, 0.0], [0.0, 1.0]])
    observation_matrix = np.array([[1.0, 0.0]])
    process_noise_covariance = np.diag(np.array(process_noise_variances, dtype=np.float64))
    observation_noise_covariance = np.array([[observation_noise_variance]])

    for i in range(len(chunk_boundaries) - 1):
        start, end = chunk_boundaries[i], chunk_boundaries[i + 1]
        chunk_values = values[start:end]
        chunk_dts = dts_s[start : end - 1]

        if len(chunk_values) < 2:
            continue

        first_valid_idx = np.where(~np.isnan(chunk_values))[0]
        if len(first_valid_idx) == 0:
            continue
        initial_value = chunk_values[first_valid_idx[0]]
        initial_state_mean = np.array([initial_value, 0.0])
        initial_state_covariance = np.eye(2) * 1e5

        filtered_means, filtered_covs = _kalman_filter(
            measurements=chunk_values,
            dts=chunk_dts,
            initial_state_mean=initial_state_mean,
            initial_state_covariance=initial_state_covariance,
            transition_matrix_fn_val=transition_matrix_template,
            observation_matrix=observation_matrix,
            process_noise_covariance=process_noise_covariance,
            observation_noise_covariance=observation_noise_covariance,
        )

        smoothed_means, smoothed_covs = _rts_smoother_numba(
            filtered_means,
            filtered_covs,
            chunk_dts,
            transition_matrix_template,
            process_noise_covariance,
        )

        all_smoothed_values[start:end] = smoothed_means[:, 0]
        all_smoothed_derivatives[start:end] = smoothed_means[:, 1]
        all_smoothed_value_variances[start:end] = smoothed_covs[:, 0, 0]
        all_smoothed_derivative_variances[start:end] = smoothed_covs[:, 1, 1]

    return SmoothResult(
        all_smoothed_values,
        all_smoothed_derivatives,
        all_smoothed_value_variances,
        all_smoothed_derivative_variances,
    )
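
A self-contained sketch on a synthetic trace, reusing the altitude noise settings from make_trajectories below; the 60 s gap splits the series into two independently smoothed chunks, and the NaN is treated as a missing observation:

import numpy as np

values = np.array([100.0, 102.0, np.nan, 106.0, 108.0, 500.0, 502.0])  # hypothetical positions
dts_s = np.array([1.0, 1.0, 1.0, 1.0, 60.0, 1.0])                      # gap > 30 s after index 4

res = smooth_time_series(
    values,
    dts_s,
    process_noise_variances=(1.0**2, 0.3**2),
    observation_noise_variance=4.0**2,
)
smoothed, rates = res.val, res.val_d  # positions and their derivatives; NaN in skipped chunks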
make_trajectories ¤
make_trajectories(
    partition: Partition,
    seed: int = 13,
    *,
    path_base: Path = PATH_PREPROCESSED,
    altitude_max: Annotated[
        float, isqx.aerospace.PRESSURE_ALTITUDE(isqx.M)
    ] = ft2m(50000),
    speed_max: Annotated[
        float, isqx.SPEED(isqx.M_PERS)
    ] = knot2mps(800),
    vertical_speed_max: Annotated[
        float, isqx.aerospace.VS(isqx.M_PERS)
    ] = fpm2mps(8000),
    track_rate_max: Annotated[float, isqx.RAD_PERS] = 0.003,
    plot_every_n_flights: int | None = None,
)

Creates train/validation split of preprocessed trajectories.

Handles the alignment of asynchronous data sources:

  1. Flight List: [takeoff, landing] constraints.
  2. Fuel Data: segment boundaries.
  3. ADS-B + ACARS: raw state observations.

It produces the standard state vector \(x_t\) required by microfuel.model.FuelBurnPredictor.

Everything related to segments (e.g. whether a particular state vector is within [start, end]) should be handled elsewhere. This function processes the entire trajectory.

Source code in src/microfuel/datasets/preprocessed.py
def make_trajectories(
    partition: Partition,
    seed: int = 13,
    *,
    path_base: Path = PATH_PREPROCESSED,
    altitude_max: Annotated[float, isqx.aerospace.PRESSURE_ALTITUDE(isqx.M)] = ft2m(50000),
    speed_max: Annotated[float, isqx.SPEED(isqx.M_PERS)] = knot2mps(800),
    vertical_speed_max: Annotated[float, isqx.aerospace.VS(isqx.M_PERS)] = fpm2mps(8000),
    track_rate_max: Annotated[float, isqx.RAD_PERS] = 0.003,
    plot_every_n_flights: int | None = None,
):
    """Creates train/validation split of preprocessed trajectories.

    Handles the alignment of asynchronous data sources:

    1. Flight List: [takeoff, landing] constraints.
    2. Fuel Data: segment boundaries.
    3. ADS-B + ACARS: raw state observations.

    It produces the standard state vector $x_t$ required by
    [`microfuel.model.FuelBurnPredictor`][].

    Everything related to segments (e.g. whether a particular state vector is within [start, end])
    should be handled elsewhere. This function processes the *entire* trajectory.
    """
    path_base.mkdir(exist_ok=True, parents=True)

    flight_list_lf = raw.scan_flight_list(partition)
    fuel_lf = raw.scan_fuel(partition)

    flight_ids_with_fuel = (
        flight_list_lf.join(fuel_lf, on="flight_id")
        .select("flight_id")
        .unique()
        .sort("flight_id")
        .collect()
        .to_series()
        .shuffle(seed=seed)
        .to_list()
    )
    logger.info(
        f"found {len(flight_ids_with_fuel)} flights with fuel data in partition `{partition}`"
    )

    flight_id_to_flight: dict[FlightId, raw.FlightListRecord] = {
        row["flight_id"]: row  # type: ignore
        for row in flight_list_lf.collect().iter_rows(named=True)
    }
    flight_id_to_segment: dict[FlightId, tuple[pl.Series, pl.Series]] = {
        flight_id: (data["start"], data["end"])
        for (flight_id,), data in fuel_lf.collect().group_by("flight_id")
    }
    icao_to_coords: dict[AirportIcao, Coordinate2D[float]] = {
        row["icao"]: Coordinate2D(lng=row["longitude"], lat=row["latitude"])
        for row in raw.scan_airports().collect().iter_rows(named=True)
    }  # type: ignore

    trajectories_all: list[pl.LazyFrame] = []
    for i, flight_id in enumerate(track(flight_ids_with_fuel, description="processing flights")):
        traj_lf = raw.scan_trajectory(flight_id, partition).with_columns(
            pl.col("timestamp").dt.replace_time_zone("UTC")
        )  # raw file has naive timestamps, cast early to avoid issues in era5 interpolation
        flight = flight_id_to_flight[flight_id]
        timestamp_takeoff = flight["takeoff"]
        timestamp_landed = flight["landed"]
        ac_type = flight["aircraft_type"]

        # segments can cover timestamps that are missing from the trajectory data, including
        # timestamps that start before takeoff or end after landing.
        # so we want to make sure one segment has at least 2 points (start and end) present
        timestamps_segment_start, timestamps_segment_end = flight_id_to_segment[flight_id]
        timestamps_required = (
            pl.concat(
                (
                    pl.Series((timestamp_takeoff, timestamp_landed)).dt.cast_time_unit("ns"),
                    timestamps_segment_start,
                    timestamps_segment_end,
                )
            )
            .unique()
            .sort()
        )

        # NOTE: discarding duplicate timestamps is a bad idea! sometimes the time gets "stuck"
        # and we lose a lot of useful information.
        traj_df = traj_lf.unique(subset=["timestamp"], keep="first").sort("timestamp").collect()
        timestamps_existing = traj_df.select("timestamp").to_series()
        # takeoff time in flight list usually precedes the first timestamp in trajectory data
        timestamps_missing = timestamps_required.filter(
            ~timestamps_required.is_in(timestamps_existing)
        )

        if timestamps_missing.len() > 0:
            # for required timestamps that are beyond what the trajectory data provides,
            # we assume they are stationary on the ground, zero filling features
            timestamp_gnd_start: datetime = min(timestamp_takeoff, timestamps_existing.min())  # type: ignore
            timestamp_gnd_end: datetime = max(timestamp_landed, timestamps_existing.max())  # type: ignore
            coord_origin, coord_dest = (
                icao_to_coords[flight["origin_icao"]],
                icao_to_coords[flight["destination_icao"]],
            )  # we don't need elevation since altitude is barometric
            trks: list[float] = traj_df.select("track").drop_nulls().to_series().to_list()
            trk_start, trk_end = (trks[0], trks[-1]) if len(trks) else (0.0, 0.0)  # ffill/bfill

            def _artificial(ts: datetime) -> raw.TrajectoryRecord:
                if ts <= timestamp_gnd_start:
                    val, lng, lat, trk = 0.0, coord_origin.lng, coord_origin.lat, trk_start
                elif ts >= timestamp_gnd_end:
                    val, lng, lat, trk = 0.0, coord_dest.lng, coord_dest.lat, trk_end
                else:
                    val, lng, lat, trk = None, None, None, None
                return raw.TrajectoryRecord(
                    timestamp=ts,
                    flight_id=flight_id,
                    typecode=ac_type,
                    latitude=lat,
                    longitude=lng,
                    altitude=val,
                    groundspeed=val,
                    track=trk,
                    vertical_rate=val,
                    mach=val,
                    TAS=val,
                    CAS=val,
                    source="artificial",
                )

            artificial_df = pl.DataFrame(_artificial(ts) for ts in timestamps_missing).with_columns(
                pl.col("timestamp").dt.cast_time_unit("ns")
            )
            full_traj_df = traj_df.vstack(artificial_df).sort("timestamp")
        else:
            full_traj_df = traj_df.sort("timestamp")

        timestamp_s = full_traj_df["timestamp"].dt.epoch(time_unit="ms").to_numpy() / 1000.0
        vs_raw = fpm2mps(full_traj_df["vertical_rate"].to_numpy())
        alt_raw = ft2m(full_traj_df["altitude"].to_numpy())
        gs_raw = knot2mps(full_traj_df["groundspeed"].to_numpy())
        track_raw_rad = np.deg2rad(full_traj_df["track"].to_numpy())
        lat_raw = full_traj_df["latitude"].to_numpy()
        lng_raw = full_traj_df["longitude"].to_numpy()

        vs_outlier_mask = (np.abs(vs_raw) > vertical_speed_max) | np.isnan(vs_raw)
        alt_outlier_mask = (alt_raw > altitude_max) | np.isnan(alt_raw)
        gs_outlier_mask = (gs_raw > speed_max) | np.isnan(gs_raw)
        track_outlier_mask = np.isnan(track_raw_rad)

        vs_with_nan = np.where(vs_outlier_mask, np.nan, vs_raw)
        alt_with_nan = np.where(alt_outlier_mask, np.nan, alt_raw)

        v_east_raw = gs_raw * np.sin(track_raw_rad)
        v_north_raw = gs_raw * np.cos(track_raw_rad)
        v_east_with_nan = np.where(gs_outlier_mask | track_outlier_mask, np.nan, v_east_raw)
        v_north_with_nan = np.where(gs_outlier_mask | track_outlier_mask, np.nan, v_north_raw)

        dts_s = np.diff(timestamp_s)
        alt_res = smooth_time_series(
            values=alt_with_nan,
            dts_s=dts_s,
            process_noise_variances=(1.0**2, 0.3**2),
            observation_noise_variance=4.0**2,
        )
        vs_res = smooth_time_series(
            values=vs_with_nan,
            dts_s=dts_s,
            process_noise_variances=(0.3**2, 0.1**2),
            observation_noise_variance=1.0**2,
        )
        v_east_res = smooth_time_series(
            values=v_east_with_nan,
            dts_s=dts_s,
            process_noise_variances=(1.0**2, 0.1**2),
            observation_noise_variance=6.0**2,
        )
        v_north_res = smooth_time_series(
            values=v_north_with_nan,
            dts_s=dts_s,
            process_noise_variances=(1.0**2, 0.1**2),
            observation_noise_variance=6.0**2,
        )

        v_east_smooth, v_east_dot_smooth = v_east_res.val, v_east_res.val_d
        v_north_smooth, v_north_dot_smooth = v_north_res.val, v_north_res.val_d
        gs_smooth = np.sqrt(v_east_smooth**2 + v_north_smooth**2)
        gs_smooth_outlier_mask = (gs_smooth > speed_max) | (gs_smooth < 0.0)
        track_rate_smooth = np.abs(
            (v_north_smooth * v_east_dot_smooth - v_east_smooth * v_north_dot_smooth)
            / np.clip(v_east_smooth**2 + v_north_smooth**2, 1e-6, None)
        )
        track_rate_outlier_mask = track_rate_smooth > track_rate_max
        gs_track_outlier_mask = gs_smooth_outlier_mask | track_rate_outlier_mask
        # groundspeed and track rate are derived from ve and vn; if either fails, mark both as outliers.
        gs_smooth[gs_track_outlier_mask] = np.nan
        track_rate_smooth[gs_track_outlier_mask] = np.nan

        v_east_interp = _np_interpolate(v_east_smooth, timestamp_s)
        v_north_interp = _np_interpolate(v_north_smooth, timestamp_s)

        # 0=N, 90=E
        track_interp_rad = np.arctan2(v_east_interp, v_north_interp)
        track_interp_deg = np.rad2deg(track_interp_rad)
        track_interp_deg = np.where(track_interp_deg < 0, track_interp_deg + 360, track_interp_deg)

        if i < 100 or (plot_every_n_flights is not None and i % plot_every_n_flights == 0):
            # import matplotlib
            import matplotlib.pyplot as plt
            from matplotlib.gridspec import GridSpec

            # matplotlib.use("WebAgg")
            from .. import PATH_PLOTS_OUTPUT

            N_PLOTS = 4
            fig = plt.figure(figsize=(9, 9 * N_PLOTS * 0.3), layout="tight")
            gs = GridSpec(N_PLOTS, 1, figure=fig)

            ax_alt = fig.add_subplot(gs[0])
            ax_vs = fig.add_subplot(gs[1], sharex=ax_alt)
            ax_gs = fig.add_subplot(gs[2], sharex=ax_alt)
            ax_track = fig.add_subplot(gs[3], sharex=ax_alt)

            for ax in [ax_alt, ax_vs, ax_gs, ax_track]:
                if ax != ax_track:
                    plt.setp(ax.get_xticklabels(), visible=False)
                ax.axvline(timestamp_takeoff.timestamp(), color="green", linewidth=0.5)
                ax.axvline(timestamp_landed.timestamp(), color="blue", linewidth=0.5)
                for j, (start_ts, end_ts) in enumerate(
                    zip(timestamps_segment_start, timestamps_segment_end)
                ):
                    ax.axvspan(start_ts.timestamp(), end_ts.timestamp(), color=f"C{j}", alpha=0.1)
                ax.grid(True, linewidth=0.2)

            ax_alt.plot(timestamp_s, alt_raw, "k.", markersize=2, alpha=0.3, label="raw altitude")
            ax_alt.plot(timestamp_s, alt_res.val, "r-", linewidth=0.5, label="smoothed altitude")
            alt_std = np.sqrt(alt_res.var_val)
            ax_alt.fill_between(
                timestamp_s,
                alt_res.val - alt_std,
                alt_res.val + alt_std,
                color="r",
                alpha=0.2,
                label=r"$\pm 1 \sigma$",
            )
            ax_alt.set_ylabel("altitude (m)")
            ax_alt.set_ylim(0, altitude_max)
            ax_alt.legend()

            ax_vs.plot(
                timestamp_s, vs_raw, "k.", markersize=2, alpha=0.3, label="raw vertical rate"
            )
            ax_vs.plot(
                timestamp_s,
                vs_res.val,
                "r-",
                linewidth=0.5,
                label="smoothed vertical rate",
            )
            ax_vs.plot(
                timestamp_s,
                alt_res.val_d,
                "b--",
                linewidth=0.5,
                label="smoothed altitude derivative",
            )
            vs_std = np.sqrt(vs_res.var_val)
            ax_vs.fill_between(
                timestamp_s,
                vs_res.val - vs_std,
                vs_res.val + vs_std,
                color="r",
                alpha=0.2,
                label=r"$\pm 1 \sigma$ (vs)",
            )
            ax_vs.set_ylabel("vertical rate (m/s)")
            ax_vs.set_ylim(-vertical_speed_max, vertical_speed_max)
            ax_vs.legend()

            ax_gs.plot(timestamp_s, gs_raw, "k.", markersize=2, alpha=0.3, label="raw groundspeed")
            ax_gs.plot(
                timestamp_s,
                gs_smooth,
                "r-",
                linewidth=0.5,
                label="smoothed groundspeed",
            )
            ax_gs.set_ylabel("groundspeed (m/s)")
            ax_gs.set_ylim(0, speed_max)
            ax_gs.legend()

            ax_track.plot(
                timestamp_s, track_interp_deg, "r.", markersize=0.5, label="smoothed track"
            )
            ax_track.set_ylabel("track (deg)")
            ax_track.set_ylim(0, 360)
            ax_track.legend(loc="upper left")
            ax_track_rate = ax_track.twinx()
            ax_track_rate.plot(
                timestamp_s,
                np.rad2deg(track_rate_smooth),
                "b.",
                markersize=2,
                label="track rate",
            )
            ax_track_rate.set_ylabel("track rate (deg/s)", color="b")
            ax_track_rate.tick_params(axis="y", labelcolor="b")
            ax_track_rate.legend(loc="lower right")
            ax_track_rate.set_ylim(-np.rad2deg(track_rate_max), np.rad2deg(track_rate_max))

            ax_track.set_xlabel("time (s)")
            fig.suptitle(f"flight id: {flight_id}")

            # plt.show()
            output_dir = PATH_PLOTS_OUTPUT / "preprocessed_trajectories"
            output_dir.mkdir(exist_ok=True, parents=True)
            output_path = output_dir / f"{partition}_{flight_id}.png"
            fig.savefig(output_path, dpi=300)
            plt.close(fig)

        # it's possible that the very start of the smoothed data isn't processed,
        # so we must interpolate.
        processed_traj_df = pl.DataFrame(
            {
                "timestamp": full_traj_df["timestamp"],
                "time_since_takeoff": (
                    full_traj_df["timestamp"] - timestamp_takeoff
                ).dt.total_seconds(fractional=True),
                "time_till_arrival": (
                    timestamp_landed - full_traj_df["timestamp"]
                ).dt.total_seconds(fractional=True),
                "latitude": _np_interpolate(lat_raw, timestamp_s),
                "longitude": _np_interpolate(lng_raw, timestamp_s),
                "vertical_rate": _np_interpolate(vs_res.val, timestamp_s),
                "vertical_rate_is_outlier": (vs_outlier_mask | np.isnan(vs_res.val)),
                "altitude": _np_interpolate(alt_res.val, timestamp_s),
                "altitude_is_outlier": (alt_outlier_mask | np.isnan(alt_res.val)),
                "groundspeed": _np_interpolate(gs_smooth, timestamp_s),
                "groundspeed_is_outlier": (
                    gs_outlier_mask | gs_smooth_outlier_mask | np.isnan(gs_smooth)
                ),
                "track": track_interp_deg,
                "track_rate": _np_interpolate(track_rate_smooth, timestamp_s),
                "track_rate_is_outlier": (track_rate_outlier_mask | np.isnan(track_rate_smooth)),
            }
        ).with_columns(pl.lit(flight_id).alias("flight_id"))

        trajectories_all.append(processed_traj_df.lazy())

    output_path = path_base / f"trajectories_{partition}.parquet"
    stop_event = multiprocessing.Event()
    monitor_process = multiprocessing.Process(
        target=_monitor_file_size, args=(output_path, stop_event)
    )
    monitor_process.start()
    try:
        pl.concat(trajectories_all).sink_parquet(output_path)
    finally:
        stop_event.set()
        monitor_process.join()
    logger.info(f"wrote state vectors to {output_path}")
altitude_to_pressure_std ¤
altitude_to_pressure_std(altitude_m)
Source code in src/microfuel/datasets/preprocessed.py
def altitude_to_pressure_std(altitude_m):
    # P ~ P0 * (1 - L*h/T0)^(gM/RL)
    return 1013.25 * (1 - 2.25577e-5 * altitude_m).pow(5.25588)
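
This is the ISA troposphere barometric formula with \(P_0 = 1013.25\) hPa, \(L = 0.0065\) K/m and \(T_0 = 288.15\) K, so \(L/T_0 \approx 2.25577 \times 10^{-5}\ \mathrm{m^{-1}}\) and \(gM/(RL) \approx 5.25588\):

\[
P = P_0 \left(1 - \frac{L h}{T_0}\right)^{\frac{g M}{R L}}
\]

For example, \(h = 11000\) m (the tropopause) gives \(P \approx 226\) hPa.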
make_era5 ¤
make_era5(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
    path_raw_weather: Path = PATH_DATA_RAW / "era5",
)
Source code in src/microfuel/datasets/preprocessed.py
def make_era5(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
    path_raw_weather: Path = PATH_DATA_RAW / "era5",
):
    import xarray as xr

    path_base.mkdir(exist_ok=True, parents=True)
    path_out_dir = path_base / f"weather_{partition}"
    path_out_dir.mkdir(exist_ok=True, parents=True)

    traj_lf = pl.scan_parquet(path_base / f"trajectories_{partition}.parquet")

    df_coords = (
        traj_lf.select(
            "flight_id",
            "timestamp",
            "latitude",
            "longitude",
            "altitude",
        )
        .with_columns(
            date_key=pl.col("timestamp").dt.convert_time_zone("UTC").dt.date(),
            pressure_level=altitude_to_pressure_std(pl.col("altitude")),
            longitude_era5=(pl.col("longitude") + 360) % 360,
        )
        .sort("timestamp")
        .collect()
    )

    unique_dates = df_coords["date_key"].unique().sort()
    logger.info(f"extracting weather for {len(unique_dates)} unique days")

    def _process_variable(
        variable_dir: Path,
        variable_names: list[str],
        batch_times: tuple,
        targets: dict,
    ) -> np.ndarray | None:
        files = sorted(variable_dir.glob("*.nc"), key=lambda p: int(p.stem))
        if not files:
            return None

        levels = [int(p.stem) for p in files]

        try:
            ds = xr.open_mfdataset(
                files,
                combine="nested",
                concat_dim="level",
                parallel=False,
                chunks={"time": 1},
            )
            ds.coords["level"] = levels

            var_name = next((v for v in variable_names if v in ds), None)
            if not var_name:
                raise ValueError(f"vars {variable_names} not found in {variable_dir}")

            ds = ds.sortby(["time", "latitude", "level"])

            min_t, max_t = batch_times
            da_sliced = ds[var_name].sel(time=slice(min_t, max_t))

            da_loaded = da_sliced.load()
            ds.close()

            # NOTE: fill_value enables extrapolation.
            # consider a point at 23:55,
            # we would need 23:00 and 00:00 (the latter is located in a
            # different file) to interpolate. however, we cannot afford spending
            # 2x RAM, so we just have to deal with it for now.
            interp_res = da_loaded.interp(
                time=targets["time"],
                latitude=targets["latitude"],
                longitude=targets["longitude"],
                level=targets["level"],
                method="linear",
                kwargs={"bounds_error": False, "fill_value": None},
            )

            result = interp_res.values

            del da_loaded
            del ds
            gc.collect()

            return result

        except Exception:
            logger.error(f"failed to process {variable_dir}:\n{traceback.format_exc()}")
            return None

    for date_key in track(unique_dates, description="processing daily weather"):
        output_path = path_out_dir / f"{date_key}.parquet"
        if output_path.exists():
            continue

        batch = df_coords.filter(pl.col("date_key") == date_key)
        if batch.height == 0:
            continue

        logger.info(f"processing {date_key}: {batch.height} points")

        year = f"{date_key.year}"
        month = f"{date_key.month:02d}"
        day = f"{date_key.day:02d}"
        day_path = path_raw_weather / year / month / day

        if not day_path.exists():
            continue

        target_time = xr.DataArray(batch["timestamp"].to_numpy(), dims="points")
        target_lats = xr.DataArray(batch["latitude"].to_numpy(), dims="points")
        target_lons = xr.DataArray(batch["longitude_era5"].to_numpy(), dims="points")
        target_level = xr.DataArray(batch["pressure_level"].to_numpy(), dims="points")

        targets = {
            "time": target_time,
            "latitude": target_lats,
            "longitude": target_lons,
            "level": target_level,
        }

        min_time = batch["timestamp"].min().astimezone(timezone.utc).replace(  # type: ignore
            tzinfo=None
        ) - timedelta(hours=2)
        max_time = batch["timestamp"].max().astimezone(timezone.utc).replace(  # type: ignore
            tzinfo=None
        ) + timedelta(hours=2)

        u_values = _process_variable(
            day_path / "u_component_of_wind",
            ["u", "u_component_of_wind", "var131"],
            (min_time, max_time),
            targets,
        )
        if u_values is None:
            continue

        v_values = _process_variable(
            day_path / "v_component_of_wind",
            ["v", "v_component_of_wind", "var132"],
            (min_time, max_time),
            targets,
        )
        if v_values is None:
            continue

        chunk_res = pl.DataFrame(
            {
                "flight_id": batch["flight_id"],
                "timestamp": batch["timestamp"],
                "u_wind": u_values,
                "v_wind": v_values,
            }
        )
        chunk_res.write_parquet(output_path)

        del chunk_res
        del u_values
        del v_values
        gc.collect()

    logger.info(f"wrote daily weather chunks to {path_out_dir}")
make_derived_features ¤
make_derived_features(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
)

Warning

This function is unused. Integration of weather features is planned for the future.

Source code in src/microfuel/datasets/preprocessed.py
def make_derived_features(partition: Partition, *, path_base: Path = PATH_PREPROCESSED):
    """
    !!! warning
        This function is unused. Integration of weather features is planned for the future.
    """
    traj_lf = (
        pl.scan_parquet(path_base / f"trajectories_{partition}.parquet")
        .select("flight_id", "timestamp", "groundspeed", "track")
        .with_row_index("row_idx")
    )
    weather_lf = (
        pl.scan_parquet(path_base / f"weather_{partition}/*.parquet")
        .select("flight_id", "timestamp", "u_wind", "v_wind")
        .unique(subset=["flight_id", "timestamp"])
    )

    ve = pl.col("groundspeed") * (deg2rad(pl.col("track"))).sin()
    vn = pl.col("groundspeed") * (deg2rad(pl.col("track"))).cos()

    u, v = pl.col("u_wind"), pl.col("v_wind")

    va_e = ve - u
    va_n = vn - v

    tas = (va_e.pow(2) + va_n.pow(2)).sqrt()
    wind_dot = ve * u + vn * v

    derived_lf = (
        traj_lf.join(weather_lf, on=["flight_id", "timestamp"], how="left")
        .sort("row_idx")
        .select(
            "flight_id",
            "timestamp",
            tas.alias("true_airspeed"),
            wind_dot.alias("wind_dot_ground"),
        )
    )

    path_out = path_base / f"derived_{partition}.parquet"
    derived_lf.sink_parquet(path_out)
    logger.info(f"wrote derived features to {path_out}")
make_standardisation_stats ¤
make_standardisation_stats(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
)
Source code in src/microfuel/datasets/preprocessed.py
def make_standardisation_stats(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
):
    splits = load_splits(partition, path_base=path_base)
    train_segment_ids = splits["train"]
    logger.info(f"computing standardisation stats from {len(train_segment_ids)} train segments")

    fuel_lf = raw.scan_fuel(partition).filter(pl.col("idx").is_in(train_segment_ids))
    train_flight_ids_df = fuel_lf.select("flight_id").unique().collect()
    flight_list_df = (
        raw.scan_flight_list(partition)
        .filter(pl.col("flight_id").is_in(train_flight_ids_df["flight_id"]))
        .collect()
    )

    flight_duration_s = (flight_list_df["landed"] - flight_list_df["takeoff"]).dt.total_seconds(
        fractional=True
    )
    standardisation_stats: Stats = {
        "flight_duration": {
            "mean": flight_duration_s.mean(),
            "std": flight_duration_s.std(),
        }
    }  # type: ignore

    trajectory_iterator = TrajectoryIterator(
        partition=partition,
        segment_ids=train_segment_ids,
        start_to_end_only=True,
    )

    features_to_stat = [*STATE_FEATURES, "flight_progress"]
    running_stats = {
        feature: {"sum": 0.0, "sum_sq": 0.0, "count": 0} for feature in features_to_stat
    }

    flight_id_to_duration = {
        row["flight_id"]: (row["landed"] - row["takeoff"]).total_seconds()
        for row in flight_list_df.iter_rows(named=True)
    }

    for trajectory in track(trajectory_iterator, description="computing stats from train segments"):
        segment_df = trajectory.features_df
        count = len(segment_df)
        if count == 0:
            continue

        duration_s = flight_id_to_duration.get(trajectory.info["flight_id"])
        if duration_s is not None and duration_s > 0:
            progress = (segment_df["timestamp"] - trajectory.info["takeoff"]).dt.total_seconds(
                fractional=True
            ) / duration_s
            running_stats["flight_progress"]["sum"] += progress.sum()
            running_stats["flight_progress"]["sum_sq"] += (progress**2).sum()
            running_stats["flight_progress"]["count"] += count

        stats_for_segment = segment_df.select(
            [pl.sum(col).alias(f"{col}_sum") for col in STATE_FEATURES]
            + [(pl.col(col).pow(2)).sum().alias(f"{col}_sum_sq") for col in STATE_FEATURES]
        ).row(0, named=True)

        for feature in STATE_FEATURES:
            running_stats[feature]["sum"] += stats_for_segment[f"{feature}_sum"] or 0
            running_stats[feature]["sum_sq"] += stats_for_segment[f"{feature}_sum_sq"] or 0
            running_stats[feature]["count"] += count

    for feature, stats in running_stats.items():
        count = stats["count"]
        assert count > 2, f"not enough data for feature {feature}"
        mean = stats["sum"] / count
        variance = (stats["sum_sq"] / count) - (mean**2)
        assert variance >= 1e-9, f"variance for {feature} is negative or too small"
        std = np.sqrt(variance)

        standardisation_stats[feature] = {"mean": mean, "std": std}

    output_path = path_base / f"stats_{partition}.json"
    with open(output_path, "w") as f:
        json.dump(standardisation_stats, f, indent=2)
    logger.info(f"wrote stats to {output_path}")
load_splits ¤
load_splits(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
) -> dict[Split, list[SegmentId]]
Source code in src/microfuel/datasets/preprocessed.py
def load_splits(
    partition: Partition, *, path_base: Path = PATH_PREPROCESSED
) -> dict[Split, list[SegmentId]]:
    with open(path_base / f"splits_{partition}.json") as f:
        return json.load(f)
load_standardisation_stats ¤
load_standardisation_stats(
    partition: Partition,
    *,
    path_base: Path = PATH_PREPROCESSED,
) -> Stats
Source code in src/microfuel/datasets/preprocessed.py
def load_standardisation_stats(
    partition: Partition, *, path_base: Path = PATH_PREPROCESSED
) -> Stats:
    with open(path_base / f"stats_{partition}.json") as f:
        return json.load(f)
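
Typical usage of the two loaders together, standardising a raw column with the train-split statistics (a sketch; "altitude" is assumed to be one of the STATE_FEATURES):

import polars as pl

splits = load_splits("phase1")
stats = load_standardisation_stats("phase1")

train_ids = splits["train"]  # the segment ids the stats were fitted on
mu, sigma = stats["altitude"]["mean"], stats["altitude"]["std"]

df = pl.DataFrame({"altitude": [35_000.0, 37_000.0]})
df = df.with_columns((pl.col("altitude") - mu) / sigma)
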
prepare_iterator_data ¤
prepare_iterator_data(
    partition: Partition,
    segment_ids: Collection[SegmentId] | None = None,
    stats: Stats | None = None,
    path_base: Path = PATH_PREPROCESSED,
) -> IteratorData

Prepares data required by the dataloader.

Source code in src/microfuel/datasets/preprocessed.py
def prepare_iterator_data(
    partition: Partition,
    segment_ids: Collection[SegmentId] | None = None,
    stats: Stats | None = None,
    path_base: Path = PATH_PREPROCESSED,
) -> IteratorData:
    """Prepares data required by the dataloader."""
    fuel_lf = raw.scan_fuel(partition)
    if segment_ids:
        fuel_lf = fuel_lf.filter(pl.col("idx").is_in(segment_ids))

    flight_list_lf = raw.scan_flight_list(partition)
    segments_df = (
        fuel_lf.join(
            flight_list_lf.select("flight_id", "takeoff", "landed", "aircraft_type"),
            on="flight_id",
        )
        .sort("flight_id")
        .collect()
    )

    # optimisation: select specific columns early to avoid OOM during large joins
    traj_cols = ["flight_id", "timestamp", *CORE_FEATURES]
    traj_lf = pl.scan_parquet(path_base / f"trajectories_{partition}.parquet").select(traj_cols)
    derived_path = path_base / f"derived_{partition}.parquet"
    if derived_path.exists() and WEATHER_FEATURES:
        derived_lf = pl.scan_parquet(derived_path)
        traj_lf = pl.concat(
            [traj_lf, derived_lf.select(WEATHER_FEATURES)], how="horizontal"
        )  # NOTE: we do not do a join here to avoid OOM: we already made sure rows align perfectly

    if stats is not None:
        standardisation_exprs = [
            ((pl.col(f) - stats[f]["mean"]) / stats[f]["std"]).alias(f) for f in STATE_FEATURES
        ]
        traj_lf = traj_lf.with_columns(standardisation_exprs)

    return IteratorData(segments_df=segments_df, traj_lf=traj_lf)
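
A sketch of the intended wiring: restrict to the train split, standardise with the train statistics, and hand the result to the iterator (the exact downstream usage may differ):

splits = load_splits("phase1")
stats = load_standardisation_stats("phase1")

data = prepare_iterator_data(
    "phase1",
    segment_ids=splits["train"],
    stats=stats,
)
print(data.segments_df.height, "train segments")
# data.traj_lf stays lazy; consumers collect per-flight slices on demand
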

raw ¤

Config ¤

Bases: TypedDict

Source code in src/microfuel/datasets/raw.py
class Config(TypedDict):
    team_id: int
    team_name: str
    bucket_access_key: str
    bucket_access_secret: str
team_id instance-attribute ¤
team_id: int
team_name instance-attribute ¤
team_name: str
bucket_access_key instance-attribute ¤
bucket_access_key: str
bucket_access_secret instance-attribute ¤
bucket_access_secret: str
FuelRecord ¤

Bases: TypedDict

Fuel consumption data for a given flight interval. Path: fuel_{partition}.parquet.

Source code in src/microfuel/datasets/raw.py
class FuelRecord(TypedDict):
    """Fuel consumption data for a given flight interval. Path: `fuel_{partition}.parquet`."""

    idx: Annotated[int, pl.Int64]
    """Unique row identifier."""
    flight_id: Annotated[FlightId, pl.Utf8]
    """Links to the flight list and trajectory."""
    start: Annotated[datetime, pl.Datetime(time_zone="UTC")]
    """The start timestamp of the interval (usually an ACARS report)."""
    end: Annotated[datetime, pl.Datetime(time_zone="UTC")]
    """The end timestamp of the interval."""
    fuel_kg: Annotated[float, pl.Float64, isqx.MASS(isqx.KG)]
    """The target variable.

    !!! warning
        Note that this variable has quantisation artifacts: the data is not a single
        continuous distribution but a composite of at least two distinct sources,
        imperial (pounds) and metric (kilograms) units, each with a rounding step to
        two significant figures.
    """
idx instance-attribute ¤
idx: Annotated[int, pl.Int64]

Unique row identifier.

flight_id instance-attribute ¤
flight_id: Annotated[FlightId, pl.Utf8]

Links to the flight list and trajectory.

start instance-attribute ¤
start: Annotated[datetime, pl.Datetime(time_zone='UTC')]

The start timestamp of the interval (usually an ACARS report).

end instance-attribute ¤
end: Annotated[datetime, pl.Datetime(time_zone='UTC')]

The end timestamp of the interval.

fuel_kg instance-attribute ¤
fuel_kg: Annotated[float, pl.Float64, isqx.MASS(isqx.KG)]

The target variable.

Warning

Note that this variable has quantisation artifacts: the data is not a single continuous distribution but a composite of at least two distinct sources, imperial (pounds) and metric (kilograms) units, each with a rounding step to two significant figures.
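
To see the artifact concretely: a value reported in pounds and rounded to two significant figures lands off the round-kilogram lattice after conversion, so the kg distribution mixes two quantisation grids. A small illustration (0.45359237 kg/lb is exact):

from math import floor, log10

LB_TO_KG = 0.45359237

def round_2sf(x: float) -> float:
    return round(x, 1 - int(floor(log10(abs(x)))))

print(round_2sf(2_834.0) * LB_TO_KG)  # 1270.058636: pounds source, off-lattice in kg
print(round_2sf(1_285.5))             # 1300.0: metric source, on-lattice
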

FlightListRecord ¤

Bases: TypedDict

Metadata for each flight in the dataset. Path: flight_list_{partition}.parquet.

Source code in src/microfuel/datasets/raw.py
class FlightListRecord(TypedDict):
    """Metadata for each flight in the dataset. Path: `flight_list_{partition}.parquet`."""

    flight_id: Annotated[FlightId, pl.Utf8]
    """A unique identifier for the flight."""
    flight_date: Annotated[datetime, pl.Date]
    """The date of the flight."""
    takeoff: Annotated[datetime, pl.Datetime(time_zone="UTC")]
    """The timestamp of takeoff."""
    landed: Annotated[datetime, pl.Datetime(time_zone="UTC")]
    """The timestamp of landing."""
    origin_icao: Annotated[AirportIcao, pl.Utf8]
    """ICAO code for the departure airport."""
    destination_icao: Annotated[AirportIcao, pl.Utf8]
    """ICAO code for the destination airport."""
    aircraft_type: Annotated[AircraftType, pl.Utf8]
flight_id instance-attribute ¤
flight_id: Annotated[FlightId, pl.Utf8]

A unique identifier for the flight.

flight_date instance-attribute ¤
flight_date: Annotated[datetime, pl.Date]

The date of the flight.

takeoff instance-attribute ¤
takeoff: Annotated[datetime, pl.Datetime(time_zone='UTC')]

The timestamp of takeoff.

landed instance-attribute ¤
landed: Annotated[datetime, pl.Datetime(time_zone='UTC')]

The timestamp of landing.

origin_icao instance-attribute ¤
origin_icao: Annotated[AirportIcao, pl.Utf8]

ICAO code for the departure airport.

destination_icao instance-attribute ¤
destination_icao: Annotated[AirportIcao, pl.Utf8]

ICAO code for the destination airport.

aircraft_type instance-attribute ¤
aircraft_type: Annotated[AircraftType, pl.Utf8]
AirportRecord ¤

Bases: TypedDict

Airport metadata. Path: apt.parquet.

Source code in src/microfuel/datasets/raw.py
class AirportRecord(TypedDict):
    """Airport metadata. Path: `apt.parquet`."""

    icao: Annotated[AirportIcao, pl.Utf8]
    latitude: Annotated[float, pl.Float64, isqx.LATITUDE(isqx.DEG)]
    longitude: Annotated[float, pl.Float64, isqx.LONGITUDE(isqx.DEG)]
    elevation: Annotated[float, pl.Float64, isqx.aerospace.GEOMETRIC_ALTITUDE(isqx.usc.FT)] | None
icao instance-attribute ¤
icao: Annotated[AirportIcao, pl.Utf8]
latitude instance-attribute ¤
latitude: Annotated[
    float, pl.Float64, isqx.LATITUDE(isqx.DEG)
]
longitude instance-attribute ¤
longitude: Annotated[
    float, pl.Float64, isqx.LONGITUDE(isqx.DEG)
]
elevation instance-attribute ¤
elevation: (
    Annotated[
        float,
        pl.Float64,
        isqx.aerospace.GEOMETRIC_ALTITUDE(isqx.usc.FT),
    ]
    | None
)
TrajectoryRecord ¤

Bases: TypedDict

Flight trajectory data points. Path: flights_{partition}/{flight_id}.parquet.

Source code in src/microfuel/datasets/raw.py
class TrajectoryRecord(TypedDict):
    """Flight trajectory data points. Path: `flights_{partition}/{flight_id}.parquet`."""

    timestamp: Annotated[datetime, pl.Datetime(time_unit="ns", time_zone="UTC")]
    flight_id: Annotated[FlightId, pl.Utf8]
    typecode: Annotated[AircraftType, pl.Utf8]
    latitude: Annotated[float, pl.Float64, isqx.LATITUDE(isqx.DEG)] | None
    """Latitude, encoded via Compact Positional Reporting (CPR, tc=9-18, 20-22)
    We do not have access to uncertainty/quantisation, can be anywhere from:

    - navigational integrity category: nic=11 (rc < 7.5m)..nic=8 (rc < 185m)
    - navigational accuracy category: nacp=10 (epu < 10m)..nacp=8 (epu < 93m)"""
    longitude: Annotated[float, pl.Float64, isqx.LONGITUDE(isqx.DEG)] | None  # see above.
    altitude: Annotated[float, pl.Float64, isqx.aerospace.PRESSURE_ALTITUDE(isqx.usc.FT)] | None
    """Barometric altitude (tc=9-18, 12-bit field). Not to be confused with GNSS altitude (tc=20-22)

    Quantisation: 'q' bit (bit 8 of the field):
    - q=1: 25-foot increments. altitude = (decimal value of 11 bits) * 25 - 1000 ft.
    - q=0: 100-foot increments, using gray code for altitudes > 50,175 ft.

    Uncertainty: depends on barometric altitude quality (baq)."""
    groundspeed: Annotated[float, pl.Float64, isqx.aerospace.GS(isqx.usc.KNOT)] | None
    """Ground speed (GNSS or inertial reference system, tc=19, subtypes1-2).

    Not transmitted directly, encoded as two signed velocity components (east-west velocity,
    north-south velocity):

    - groundspeed = sqrt(vew^2 + vns^2)
    - track angle = atan2(vew, vns)

    Quantisation: 1 kt (subsonic), 4 kt (supersonic).
    Uncertainty: nacv=4 (< 0.3m/s), nacv=3 (< 1.0m/s), nacv=2 (< 3.0m/s), nacv=1 (< 10.0m/s)"""
    track: Annotated[float, pl.Float64, isqx.DEG] | None  # see above.
    vertical_rate: (
        Annotated[float, pl.Float64, isqx.aerospace.VS(isqx.usc.FT * isqx.MIN**-1)] | None
    )
    """Vertical rate (`vrsrc` specifies origin: GNSS or barometric, tc=19).

    A sign bit indicates climb or descent; a 9-bit value (vr) encodes the magnitude.
    Vertical speed (ft/min) = 64 * (vr - 1).

    Uncertainty: linked to vertical component of nacv."""
    mach: Annotated[float, pl.Float64, isqx.MACH_NUMBER] | None
    """Mach number (Mode S, BDS 6,0, 10 bits, mb 25-34).

    Quantisation: 0.004."""
    TAS: Annotated[float, pl.Float64, isqx.aerospace.TAS(isqx.usc.KNOT)] | None
    """True airspeed.

    - ADS-B (tc=19, subtype 3/4) - Quantisation: 1 kt (subtype 3), 4 kt (subtype 4).
    - Mode S (BDS 5,0 track and turn report, 10 bits, mb 47-56) - Quantisation: 2 kt"""
    CAS: Annotated[float, pl.Float64, isqx.aerospace.CAS(isqx.usc.KNOT)] | None
    """Calibrated airspeed. Not broadcast, but likely derived from indicated airspeed (BDS 6,0).

    Quantisation: 1 kt."""
    source: Annotated[Literal["adsb", "acars", "artificial"], pl.Utf8]
    """Data source.

    Data from `adsb` and `acars` have different characteristics.
    `acars` data, for instance, may include `mach`, `TAS`, and `CAS`,
    which are not present in standard ADS-B reports.

    `artificial` data points are inserted from [microfuel.datasets.raw.FlightListRecord][] to aid
    interpolation.
    """
timestamp instance-attribute ¤
timestamp: Annotated[
    datetime, pl.Datetime(time_unit="ns", time_zone="UTC")
]
flight_id instance-attribute ¤
flight_id: Annotated[FlightId, pl.Utf8]
typecode instance-attribute ¤
typecode: Annotated[AircraftType, pl.Utf8]
latitude instance-attribute ¤
latitude: (
    Annotated[float, pl.Float64, isqx.LATITUDE(isqx.DEG)]
    | None
)

Latitude, encoded via Compact Position Reporting (CPR, tc=9-18, 20-22). We do not have access to the uncertainty/quantisation, which can be anywhere from:

  • navigational integrity category: nic=11 (rc < 7.5m)..nic=8 (rc < 185m)
  • navigational accuracy category: nacp=10 (epu < 10m)..nacp=8 (epu < 93m)
longitude instance-attribute ¤
longitude: (
    Annotated[float, pl.Float64, isqx.LONGITUDE(isqx.DEG)]
    | None
)
altitude instance-attribute ¤
altitude: (
    Annotated[
        float,
        pl.Float64,
        isqx.aerospace.PRESSURE_ALTITUDE(isqx.usc.FT),
    ]
    | None
)

Barometric altitude (tc=9-18, 12-bit field). Not to be confused with GNSS altitude (tc=20-22).

Quantisation: 'q' bit (bit 8 of the field):

  • q=1: 25-foot increments. altitude = (decimal value of 11 bits) * 25 - 1000 ft.
  • q=0: 100-foot increments, using Gray code for altitudes > 50,175 ft.

Uncertainty: depends on barometric altitude quality (baq).

groundspeed instance-attribute ¤
groundspeed: (
    Annotated[
        float, pl.Float64, isqx.aerospace.GS(isqx.usc.KNOT)
    ]
    | None
)

Ground speed (GNSS or inertial reference system, tc=19, subtypes 1-2).

Not transmitted directly, encoded as two signed velocity components (east-west velocity, north-south velocity):

  • groundspeed = sqrt(vew^2 + vns^2)
  • track angle = atan2(vew, vns)

Quantisation: 1 kt (subsonic), 4 kt (supersonic). Uncertainty: nacv=4 (< 0.3m/s), nacv=3 (< 1.0m/s), nacv=2 (< 3.0m/s), nacv=1 (< 10.0m/s)

track instance-attribute ¤
track: Annotated[float, pl.Float64, isqx.DEG] | None
vertical_rate instance-attribute ¤
vertical_rate: (
    Annotated[
        float,
        pl.Float64,
        isqx.aerospace.VS(isqx.usc.FT * isqx.MIN**-1),
    ]
    | None
)

Vertical rate (vrsrc specifies origin: GNSS or barometric, tc=19).

A sign bit indicates climb or descent; a 9-bit value (vr) encodes the magnitude. Vertical speed (ft/min) = 64 * (vr - 1).

Uncertainty: linked to vertical component of nacv.
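
The decoding formulas quoted in these field docstrings, as executable arithmetic (made-up field values, not a full ADS-B parser):

import math

# tc=19 subtype 1: signed east-west / north-south velocity components (kt)
vew, vns = -120.0, 350.0
groundspeed = math.hypot(vew, vns)                # sqrt(vew^2 + vns^2) = 370.0 kt
track = math.degrees(math.atan2(vew, vns)) % 360  # ~341.1 deg

# altitude field with q=1: 25 ft increments, -1000 ft offset
alt_bits = 0b10110101100                          # the 11 bits left after dropping q
altitude_ft = alt_bits * 25 - 1000                # 35300 ft

# vertical rate: sign bit plus 9-bit magnitude vr
sign, vr = -1, 15
vertical_rate_fpm = sign * 64 * (vr - 1)          # -896 ft/min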

mach instance-attribute ¤
mach: Annotated[float, pl.Float64, isqx.MACH_NUMBER] | None

Mach number (Mode S, BDS 6,0, 10 bits, mb 25-34).

Quantisation: 0.004.

TAS instance-attribute ¤
TAS: (
    Annotated[
        float, pl.Float64, isqx.aerospace.TAS(isqx.usc.KNOT)
    ]
    | None
)

True airspeed.

  • ADS-B (tc=19, subtype 3/4) - Quantisation: 1 kt (subtype 3), 4 kt (subtype 4).
  • Mode S (BDS 5,0 track and turn report, 10 bits, mb 47-56) - Quantisation: 2 kt
CAS instance-attribute ¤
CAS: (
    Annotated[
        float, pl.Float64, isqx.aerospace.CAS(isqx.usc.KNOT)
    ]
    | None
)

Calibrated airspeed. Not broadcast, but likely derived from indicated airspeed (BDS 6,0).

Quantisation: 1 kt.

source instance-attribute ¤
source: Annotated[
    Literal["adsb", "acars", "artificial"], pl.Utf8
]

Data source.

Data from adsb and acars have different characteristics. acars data, for instance, may include mach, TAS, and CAS, which are not present in standard ADS-B reports.

artificial data points are inserted from microfuel.datasets.raw.FlightListRecord to aid interpolation.

SubmissionRecord ¤

Bases: TypedDict

Source code in src/microfuel/datasets/raw.py
class SubmissionRecord(TypedDict):
    idx: Annotated[FlightId, pl.Utf8]
    """The row identifier."""
    predicted_fuel_kg: Annotated[float, pl.Float64, isqx.MASS(isqx.KG)]
    """Predicted fuel consumption (kilograms) for the given interval."""
idx instance-attribute ¤
idx: Annotated[FlightId, pl.Utf8]

The row identifier.

predicted_fuel_kg instance-attribute ¤
predicted_fuel_kg: Annotated[
    float, pl.Float64, isqx.MASS(isqx.KG)
]

Predicted fuel consumption (kilograms) for the given interval.

load_config ¤
load_config(fp: Path = PATH_DATA / 'config.toml') -> Config
Source code in src/microfuel/datasets/raw.py
def load_config(fp: Path = PATH_DATA / "config.toml") -> Config:
    if sys.version_info >= (3, 11):
        import tomllib
    else:
        import tomli as tomllib
    with open(fp, "rb") as f:
        return tomllib.load(f)  # type: ignore
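
The expected shape of data/config.toml, mirroring the Config keys (placeholder values):

team_id = 42
team_name = "example-team"
bucket_access_key = "<access key>"
bucket_access_secret = "<access secret>"
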
setup_mc_alias ¤
setup_mc_alias(
    bucket_access_key: str,
    bucket_access_secret: str,
    endpoint_url: str = "https://s3.opensky-network.org:443",
    alias_name: str = "prc25",
) -> int
Source code in src/microfuel/datasets/raw.py
def setup_mc_alias(
    bucket_access_key: str,
    bucket_access_secret: str,
    endpoint_url: str = "https://s3.opensky-network.org:443",
    alias_name: str = "prc25",
) -> int:
    cmd = f"mc alias set {alias_name} {endpoint_url} {bucket_access_key} {bucket_access_secret}"
    return os.system(cmd)
download_from_s3 ¤
download_from_s3(
    bucket_access_key: str,
    bucket_access_secret: str,
    *,
    path_out: Path = PATH_DATA_RAW,
    bucket_name: str = "prc-2025-datasets",
    endpoint_url: str = "https://s3.opensky-network.org:443",
    alias_name: str = "prc2025",
) -> int

Download data from S3 using MinIO client. Not using boto3 because it is extremely slow.

Source code in src/microfuel/datasets/raw.py
def download_from_s3(
    bucket_access_key: str,
    bucket_access_secret: str,
    *,
    path_out: Path = PATH_DATA_RAW,
    bucket_name: str = "prc-2025-datasets",
    endpoint_url: str = "https://s3.opensky-network.org:443",
    alias_name: str = "prc2025",
) -> int:
    """Download data from S3 using MinIO client.
    Not using boto3 because it is extremely slow."""
    path_out.mkdir(parents=True, exist_ok=True)
    setup_mc_alias(bucket_access_key, bucket_access_secret, endpoint_url, alias_name)
    cmd = f"mc cp --recursive {alias_name}/{bucket_name}/ {path_out}/"
    return os.system(cmd)
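
End to end, the raw-data bootstrap is then assumed to look like this (requires the MinIO mc client on PATH):

config = load_config()  # reads data/config.toml
rc = download_from_s3(
    config["bucket_access_key"],
    config["bucket_access_secret"],
)
assert rc == 0, "mc cp failed"
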
scan_fuel ¤
scan_fuel(
    partition: Partition = "phase1",
    *,
    path_base: Path = PATH_DATA_RAW,
) -> Annotated[pl.LazyFrame, FuelRecord]
Source code in src/microfuel/datasets/raw.py
def scan_fuel(
    partition: Partition = "phase1", *, path_base: Path = PATH_DATA_RAW
) -> Annotated[pl.LazyFrame, FuelRecord]:
    fp = path_base / f"fuel_{partition}.parquet"
    # timestamps are in nanoseconds
    return pl.scan_parquet(fp).with_columns(
        pl.col("start").dt.replace_time_zone("UTC"), pl.col("end").dt.replace_time_zone("UTC")
    )
scan_flight_list ¤
scan_flight_list(
    partition: Partition = "phase1",
    *,
    path_base: Path = PATH_DATA_RAW,
) -> Annotated[pl.LazyFrame, FlightListRecord]
Source code in src/microfuel/datasets/raw.py
def scan_flight_list(
    partition: Partition = "phase1", *, path_base: Path = PATH_DATA_RAW
) -> Annotated[pl.LazyFrame, FlightListRecord]:
    fp = path_base / f"flight_list_{partition}.parquet"
    # timestamps are in seconds
    return pl.scan_parquet(fp).with_columns(
        pl.col("takeoff").dt.replace_time_zone("UTC"), pl.col("landed").dt.replace_time_zone("UTC")
    )
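
The scans compose lazily; a typical pattern attaches flight metadata to each fuel interval before collecting, e.g. to inspect a single aircraft type:

import polars as pl

fuel_a320 = (
    scan_fuel("phase1")
    .join(
        scan_flight_list("phase1").select("flight_id", "aircraft_type"),
        on="flight_id",
    )
    .filter(pl.col("aircraft_type") == "A320")
    .collect()
)
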
scan_airports ¤
scan_airports(
    *, path_base: Path = PATH_DATA_RAW
) -> Annotated[pl.LazyFrame, AirportRecord]
Source code in src/microfuel/datasets/raw.py
def scan_airports(*, path_base: Path = PATH_DATA_RAW) -> Annotated[pl.LazyFrame, AirportRecord]:
    fp = path_base / "apt.parquet"
    return pl.scan_parquet(fp)
scan_all_trajectories ¤
scan_all_trajectories(
    partition: Partition = "phase1",
    *,
    path_base: Path = PATH_DATA_RAW,
) -> Annotated[pl.LazyFrame, TrajectoryRecord]
Source code in src/microfuel/datasets/raw.py
def scan_all_trajectories(
    partition: Partition = "phase1", *, path_base: Path = PATH_DATA_RAW
) -> Annotated[pl.LazyFrame, TrajectoryRecord]:
    fp = path_base / f"flights_{partition}"
    return pl.scan_parquet(f"{fp}/*.parquet").with_columns(
        pl.col("timestamp").dt.replace_time_zone("UTC"),
    )
scan_trajectory ¤
scan_trajectory(
    flight_id: str,
    partition: Partition = "phase1",
    *,
    path_base: Path = PATH_DATA_RAW,
) -> Annotated[pl.LazyFrame, TrajectoryRecord]
Source code in src/microfuel/datasets/raw.py
def scan_trajectory(
    flight_id: str, partition: Partition = "phase1", *, path_base: Path = PATH_DATA_RAW
) -> Annotated[pl.LazyFrame, TrajectoryRecord]:
    fp = path_base / f"flights_{partition}" / f"{flight_id}.parquet"
    return pl.scan_parquet(fp).with_columns(
        pl.col("timestamp").dt.replace_time_zone("UTC"),
    )

hacks ¤

Vendored and patched versions of flash-linear-attention.

To avoid recompilation during variable-length training, we make the following changes:

  1. Causal Conv1D (https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/convolution.py)
     • NB removed from autotune key
     • NB removed from constexpr list in signature (kept as scalar)
     • BT fixed to 64 in wrappers
  2. L2Norm (https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py)
     • NB removed from autotune key
     • NB removed from constexpr in signature
     • T removed from constexpr in signature
  3. GatedNorm (https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/fused_norm_gate.py)
     • NB removed from autotune key
     • NB removed from constexpr in signature
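
The motivation in miniature: upstream, NB is derived from the total token count, so every distinct varlen batch shape changed the autotune key and forced a fresh Triton compile. With NB demoted to a plain runtime scalar, the key depends only on static dimensions:

import triton

# upstream behaviour: NB in the autotune key => one compile per distinct B*T
for B, T in [(3, 812), (5, 640), (2, 1297)]:  # shapes vary on every varlen batch
    NB = triton.cdiv(B * T, 1024)             # 3, 4, 3 -> constant key churn
    print(B * T, NB)

# after the patch: key=['D', 'W'] (or key=['D']) only; NB is an ordinary scalar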

NUM_WARPS_AUTOTUNE module-attribute ¤

NUM_WARPS_AUTOTUNE = (
    [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
)

FIXED_BT_CONV module-attribute ¤

FIXED_BT_CONV = 64

BT_LIST module-attribute ¤

BT_LIST = [8, 16, 32, 64, 128]

causal_conv1d_fwd_kernel ¤

causal_conv1d_fwd_kernel(
    x,
    y,
    weight,
    bias,
    residual,
    cu_seqlens,
    initial_state,
    chunk_indices,
    B,
    T,
    D: tl.constexpr,
    W: tl.constexpr,
    BT: tl.constexpr,
    BW: tl.constexpr,
    BD: tl.constexpr,
    NB,
    ACTIVATION: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.heuristics({
    'HAS_WEIGHT': lambda args: args['weight'] is not None,
    'HAS_BIAS': lambda args: args['bias'] is not None,
    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
    'USE_INITIAL_STATE': lambda args: args['initial_state'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in NUM_WARPS_AUTOTUNE
    ],
    key=['D', 'W'],  # removed NB
    **autotune_cache_kwargs,
)
@triton.jit
def causal_conv1d_fwd_kernel(
    x, y, weight, bias, residual, cu_seqlens, initial_state, chunk_indices,
    B, T,
    D: tl.constexpr, W: tl.constexpr,
    BT: tl.constexpr, BW: tl.constexpr, BD: tl.constexpr,
    NB,  # removed constexpr
    ACTIVATION: tl.constexpr,
    HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_RESIDUAL: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr, IS_VARLEN: tl.constexpr,
):
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        i_n = i_b
        bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)

    o_d = i_d * BD + tl.arange(0, BD)
    o_w = tl.arange(0, BW) + W - BW
    m_d = o_d < D
    m_w = o_w >= 0

    if HAS_WEIGHT:
        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0).to(tl.float32)

    b_y = tl.zeros((BT, BD), dtype=tl.float32)
    if not USE_INITIAL_STATE:
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi
    elif i_t * BT >= W:
        for i_w in tl.static_range(-W + 1, 1):
            p_yi = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            b_yi = tl.load(p_yi, boundary_check=(0, 1)).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi
    else:
        o_t = i_t * BT + tl.arange(0, BT)
        for i_w in tl.static_range(-W + 1, 1):
            o_x = o_t + i_w
            m_x = ((o_x >= 0) & (o_x < T))[:, None] & m_d
            m_c = ((o_x + W >= 0) & (o_x < 0))[:, None] & m_d
            b_yi = tl.load(x + bos * D + o_x[:, None] * D + o_d, mask=m_x, other=0).to(tl.float32)
            b_yi += tl.load(initial_state + i_n * D*W + o_d * W + (o_x + W)[:, None], mask=m_c, other=0).to(tl.float32)
            if HAS_WEIGHT:
                b_yi *= tl.sum(b_w * (o_w == (i_w + W - 1)), 1)
            b_y += b_yi

    if HAS_BIAS:
        b_y += tl.load(bias + o_d, mask=m_d).to(tl.float32)
    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
        b_y = b_y * tl.sigmoid(b_y)
    if HAS_RESIDUAL:
        p_residual = tl.make_block_ptr(residual + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        b_residual = tl.load(p_residual, boundary_check=(0, 1))
        b_y += b_residual

    p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    tl.store(p_y, tl.cast(b_y, dtype=p_y.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1))

causal_conv1d_bwd_kernel ¤

causal_conv1d_bwd_kernel(
    x,
    y,
    weight,
    initial_state,
    dh0,
    dht,
    dy,
    dx,
    dw,
    db,
    cu_seqlens,
    chunk_indices,
    B,
    T,
    D: tl.constexpr,
    W: tl.constexpr,
    BT: tl.constexpr,
    BW: tl.constexpr,
    BD: tl.constexpr,
    NB,
    ACTIVATION: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.heuristics({
    'HAS_WEIGHT': lambda args: args['dw'] is not None,
    'HAS_BIAS': lambda args: args['db'] is not None,
    'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
    'USE_FINAL_STATE': lambda args: args['dht'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [4, 8, 16, 32]
    ],
    key=['D', 'W'],  # removed NB
    **autotune_cache_kwargs,
)
@triton.jit
def causal_conv1d_bwd_kernel(
    x, y, weight, initial_state, dh0, dht, dy, dx, dw, db, cu_seqlens, chunk_indices,
    B, T,
    D: tl.constexpr, W: tl.constexpr,
    BT: tl.constexpr, BW: tl.constexpr, BD: tl.constexpr,
    NB,  # removed constexpr
    ACTIVATION: tl.constexpr,
    HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr, USE_FINAL_STATE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        i_tg = i_b * tl.num_programs(1) + i_t
        i_n = i_b
        bos, eos = (i_b * T).to(tl.int64), (i_b * T + T).to(tl.int64)

    o_d = i_d * BD + tl.arange(0, BD)
    o_w = tl.arange(0, BW) + W - BW
    m_d = o_d < D
    m_w = o_w >= 0

    if HAS_WEIGHT:
        p_x = tl.make_block_ptr(x + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        b_x = tl.load(p_x, boundary_check=(0, 1))
        b_w = tl.load(weight + o_d[:, None] * W + o_w, mask=m_d[:, None] & m_w, other=0)

    b_dx = tl.zeros((BT, BD), dtype=tl.float32)
    if HAS_BIAS:
        b_db = tl.zeros((BD,), dtype=tl.float32)

    if not USE_FINAL_STATE:
        for i_w in tl.static_range(0, W):
            p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
                b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32)
                b_ys = tl.sigmoid(b_y)
                b_dy = b_dy * b_ys * (1 + b_y * (1 - b_ys))
            b_wdy = b_dy
            if HAS_WEIGHT:
                b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
                b_dw = tl.sum(b_dy * b_x, 0)
                tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)
            if HAS_BIAS and i_w == 0:
                b_db += tl.sum(b_dy, 0)
            b_dx += b_wdy
    elif i_t * BT >= W:
        for i_w in tl.static_range(0, W):
            p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
                b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32)
                b_ys = tl.sigmoid(b_y)
                b_dy = b_dy * b_ys * (1 + b_y * (1 - b_ys))
            b_wdy = b_dy
            if HAS_WEIGHT:
                b_wdy = b_wdy * tl.sum(b_w * (o_w == (W - i_w - 1)), 1)
                b_dw = tl.sum(b_dy * b_x, 0)
                tl.store(dw + i_tg * D*W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)
            if HAS_BIAS and i_w == 0:
                b_db += tl.sum(b_dy, 0)
            b_dx += b_wdy
    else:
        o_t = i_t * BT + tl.arange(0, BT)
        for i_w in tl.static_range(0, W):
            p_dy = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
            b_dy_shift = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                p_y = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT + i_w, i_d * BD), (BT, BD), (1, 0))
                b_y_shift = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32)
                b_ys = tl.sigmoid(b_y_shift)
                b_dy_shift = b_dy_shift * b_ys * (1 + b_y_shift * (1 - b_ys))
            if HAS_WEIGHT:
                b_dw = tl.sum(b_dy_shift * b_x, 0)
                if USE_INITIAL_STATE:
                    mask_head_rows = (o_t < i_w)
                    b_dy_head = tl.load(dy + bos * D + o_t[:, None] * D + o_d, mask=(mask_head_rows[:, None] & m_d[None, :]), other=0.0).to(tl.float32)
                    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                        b_y_head = tl.load(y + bos * D + o_t[:, None] * D + o_d, mask=(mask_head_rows[:, None] & m_d[None, :]), other=0.0).to(tl.float32)
                        b_ys_head = tl.sigmoid(b_y_head)
                        b_dy_head = b_dy_head * b_ys_head * (1 + b_y_head * (1 - b_ys_head))
                    o_c = W - i_w + o_t
                    mask_c = (mask_head_rows & (o_c >= 1) & (o_c < W))
                    b_xc = tl.load(initial_state + i_n * D * W + o_d[None, :] * W + o_c[:, None], mask=(mask_c[:, None] & m_d[None, :]), other=0.0).to(tl.float32)
                    b_dw += tl.sum(b_dy_head * b_xc, 0)
                tl.store(dw + i_tg * D * W + o_d * W + W - i_w - 1, b_dw.to(dw.dtype.element_ty), mask=m_d)

            if HAS_BIAS and i_w == 0:
                b_db += tl.sum(b_dy_shift, 0)
            b_wdy = b_dy_shift if not HAS_WEIGHT else (b_dy_shift * tl.sum(b_w * (o_w == (W - i_w - 1)), 1))
            b_dx += b_wdy

        if USE_INITIAL_STATE:
            p_dy0 = tl.make_block_ptr(dy + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
            b_dy0 = tl.load(p_dy0, boundary_check=(0, 1)).to(tl.float32)
            if ACTIVATION == 'swish' or ACTIVATION == 'silu':
                p_y0 = tl.make_block_ptr(y + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
                b_y0 = tl.load(p_y0, boundary_check=(0, 1)).to(tl.float32)
                b_ys0 = tl.sigmoid(b_y0)
                b_dy0 = b_dy0 * b_ys0 * (1 + b_y0 * (1 - b_ys0))
            for i_w in tl.static_range(1, W):
                m_rows = (o_t < i_w)
                if HAS_WEIGHT:
                    w_idx_rows = i_w - 1 - o_t
                    w_mask = (o_w[None, :] == w_idx_rows[:, None])
                    w_pick = tl.sum(b_w[None, :, :] * w_mask[:, None, :], 2)
                else:
                    w_pick = 1.0
                contrib = (b_dy0 * w_pick).to(tl.float32)
                contrib = tl.where(m_rows[:, None] & m_d[None, :], contrib, 0.0)
                b_dh0_s = tl.sum(contrib, 0)
                tl.store(dh0 + i_t * B * D * W + i_n * D * W + o_d * W + i_w, b_dh0_s.to(dh0.dtype.element_ty, fp_downcast_rounding='rtne'), mask=m_d)

    if HAS_BIAS:
        b_db = tl.cast(b_db, dtype=db.dtype.element_ty, fp_downcast_rounding='rtne')
        tl.store(db + i_tg * D + o_d, b_db, mask=m_d)

    if USE_FINAL_STATE:
        if i_t * BT + BT >= T-W:
            start_tok = max(0, T - (W - 1))
            offset = i_t * BT + tl.arange(0, BT)
            tok_idx = offset - start_tok
            mask = (offset >= start_tok) & (offset < T)
            w_idx = 1 + tok_idx
            dht_off = i_n * D * W + o_d[None, :] * W + w_idx[:, None]
            b_dht = tl.load(dht + dht_off, mask=mask[:, None] & m_d[None, :], other=0.).to(tl.float32)
            b_dx += b_dht

    p_dx = tl.make_block_ptr(dx + bos * D, (T, D), (D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    tl.store(p_dx, tl.cast(b_dx, dtype=p_dx.dtype.element_ty, fp_downcast_rounding='rtne'), boundary_check=(0, 1))

causal_conv1d_fwd ¤

causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    activation: str | None = None,
    cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]
Source code in src/microfuel/hacks.py
@input_guard
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    activation: str | None = None,
    cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    shape = x.shape
    if x.shape[-1] != weight.shape[0]:
        x = rearrange(x, 'b t ... -> b t (...)')
    B, T, D, W = *x.shape, weight.shape[1]

    # HACK: fix BT to constant to avoid recompile on varlen batches
    BT = FIXED_BT_CONV
    BW = triton.next_power_of_2(W)
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    NB = triton.cdiv(B*T, 1024)

    y = torch.empty_like(x)
    def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
    causal_conv1d_fwd_kernel[grid](
        x=x, y=y, weight=weight, bias=bias, residual=residual,
        cu_seqlens=cu_seqlens, initial_state=initial_state, chunk_indices=chunk_indices,
        B=B, T=T, D=D, W=W, BT=BT, BW=BW, NB=NB,
        ACTIVATION=activation,
    )
    final_state = None
    if output_final_state:
        # NOTE: we use the original util since it doesn't depend on the kernel modifications
        final_state = convolution.causal_conv1d_update_states(
            x=x, state_len=W, initial_state=initial_state, cu_seqlens=cu_seqlens,
        )
    return y.view(shape), final_state
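
A varlen call sketch: sequences are packed along the time axis of a batch-1 tensor and delimited by cu_seqlens, following the flash-attention varlen convention (CUDA device assumed):

import torch

D, W = 256, 4
lens = torch.tensor([120, 87, 301], device="cuda")
cu_seqlens = torch.cat([lens.new_zeros(1), lens.cumsum(0)]).to(torch.int32)

x = torch.randn(1, int(lens.sum()), D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, W, device="cuda", dtype=torch.bfloat16)
bias = torch.zeros(D, device="cuda", dtype=torch.bfloat16)

y, final_state = causal_conv1d_fwd(
    x, weight, bias, residual=None,
    activation="silu", cu_seqlens=cu_seqlens,
)
assert y.shape == x.shape and final_state is None
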

causal_conv1d_bwd ¤

causal_conv1d_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    dht: torch.Tensor,
    weight: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    activation: str | None = None,
    cu_seqlens: torch.Tensor | None = None,
)
Source code in src/microfuel/hacks.py
def causal_conv1d_bwd(
    x: torch.Tensor,
    dy: torch.Tensor,
    dht: torch.Tensor,
    weight: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    residual: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    activation: str | None = None,
    cu_seqlens: torch.Tensor | None = None,
):
    shape = x.shape
    if x.shape[-1] != weight.shape[0]:
        x = rearrange(x, 'b t ... -> b t (...)')
    B, T, D = x.shape
    W = weight.shape[1] if weight is not None else None

    # HACK: fix BT
    BT = FIXED_BT_CONV
    BW = triton.next_power_of_2(W)
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    NB = triton.cdiv(B*T, 1024)

    y = None
    if activation is not None:
        y, _ = causal_conv1d_fwd(
            x=x, weight=weight, bias=bias, residual=None, initial_state=initial_state,
            activation=None, cu_seqlens=cu_seqlens, output_final_state=False,
        )
    dx = torch.empty_like(x)
    dw = weight.new_empty(B*NT, *weight.shape, dtype=torch.float) if weight is not None else None
    db = bias.new_empty(B*NT, *bias.shape, dtype=torch.float) if bias is not None else None
    dr = dy if residual is not None else None
    dh0 = initial_state.new_zeros(min(NT, triton.cdiv(W, BT)), *initial_state.shape) if initial_state is not None else None

    def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B)
    causal_conv1d_bwd_kernel[grid](
        x=x, y=y, weight=weight, initial_state=initial_state, dh0=dh0, dht=dht,
        dy=dy, dx=dx, dw=dw, db=db, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices,
        B=B, T=T, D=D, W=W, BT=BT, BW=BW, NB=NB,
        ACTIVATION=activation,
    )
    if weight is not None:
        dw = dw.sum(0).to(weight)
    if bias is not None:
        db = db.sum(0).to(bias)
    if initial_state is not None:
        dh0 = dh0.sum(0, dtype=torch.float32).to(initial_state)

    return dx.view(shape), dw, db, dr, dh0

l2norm_fwd_kernel ¤

l2norm_fwd_kernel(
    x,
    y,
    rstd,
    eps,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    NB,
    BT: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D'],  # NB removed
    **autotune_cache_kwargs,
)
@triton.jit
def l2norm_fwd_kernel(
    x, y, rstd, eps,
    T,  # removed constexpr
    D: tl.constexpr, BD: tl.constexpr, 
    NB,  # removed constexpr
    BT: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))

    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x, 1) + eps)
    b_y = b_x * b_rstd[:, None]

    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

l2norm_bwd_kernel ¤

l2norm_bwd_kernel(
    y,
    rstd,
    dy,
    dx,
    eps,
    T,
    D: tl.constexpr,
    BD: tl.constexpr,
    NB,
    BT: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=['D'],  # NB removed
    **autotune_cache_kwargs,
)
@triton.jit
def l2norm_bwd_kernel(
    y, rstd, dy, dx, eps,
    T,  # removed constexpr
    D: tl.constexpr, BD: tl.constexpr,
    NB,  # removed constexpr
    BT: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))

    b_y = tl.load(p_y, boundary_check=(0, 1)).to(tl.float32)
    b_rstd = tl.load(p_rstd, boundary_check=(0,)).to(tl.float32)
    b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)
    b_dx = b_dy * b_rstd[:, None] - tl.sum(b_dy * b_y, 1)[:, None] * b_y * b_rstd[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))

l2norm_fwd ¤

l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-06,
    output_dtype: torch.dtype | None = None,
)
Source code in src/microfuel/hacks.py
def l2norm_fwd(
    x: torch.Tensor,
    eps: float = 1e-6,
    output_dtype: torch.dtype | None = None,
):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    if D <= 512:
        NB = triton.cdiv(T, 2048)
        def grid(meta): return (triton.cdiv(T, meta['BT']), )
        l2norm_fwd_kernel[grid](
            x=x, y=y, rstd=rstd, eps=eps, T=T, D=D, BD=BD, NB=NB,
        )
    else:
        # fallback to original kernel1 for large D (no NB/T constexpr issues there)
        l2norm.l2norm_fwd_kernel1[(T,)](
            x=x, y=y, rstd=rstd, eps=eps, D=D, BD=BD,
        )
    return y.view(x_shape_og), rstd.view(x_shape_og[:-1])
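
The kernel output can be sanity-checked against a one-line eager reference (a sketch; CUDA assumed):

import torch

x = torch.randn(4096, 128, device="cuda")
y, rstd = l2norm_fwd(x, eps=1e-6)

# reference: x / sqrt(sum(x^2) + eps); rstd is the reciprocal norm
ref = x * torch.rsqrt(x.pow(2).sum(-1, keepdim=True) + 1e-6)
assert torch.allclose(y, ref, atol=1e-5)
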

l2norm_bwd ¤

l2norm_bwd(
    y: torch.Tensor,
    rstd: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-06,
)
Source code in src/microfuel/hacks.py
def l2norm_bwd(
    y: torch.Tensor,
    rstd: torch.Tensor,
    dy: torch.Tensor,
    eps: float = 1e-6,
):
    y_shape_og = y.shape
    y = y.view(-1, dy.shape[-1])
    dy = dy.view(-1, dy.shape[-1])
    assert dy.shape == y.shape
    dx = torch.empty_like(y)
    T, D = y.shape[0], y.shape[-1]
    MAX_FUSED_SIZE = 65536 // y.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    if D <= 512:
        NB = triton.cdiv(T, 2048)
        def grid(meta): return (triton.cdiv(T, meta['BT']), )
        l2norm_bwd_kernel[grid](
            y=y, rstd=rstd, dy=dy, dx=dx, eps=eps, T=T, D=D, BD=BD, NB=NB,
        )
    else:
        l2norm.l2norm_bwd_kernel1[(T,)](
            y=y, rstd=rstd, dy=dy, dx=dx, eps=eps, D=D, BD=BD,
        )

    return dx.view(y_shape_og)

layer_norm_gated_fwd_kernel ¤

layer_norm_gated_fwd_kernel(
    x,
    g,
    y,
    w,
    b,
    residual,
    residual_out,
    mean,
    rstd,
    eps,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NB,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.heuristics({
    'STORE_RESIDUAL_OUT': lambda args: args['residual_out'] is not None,
    'HAS_RESIDUAL': lambda args: args['residual'] is not None,
    'HAS_WEIGHT': lambda args: args['w'] is not None,
    'HAS_BIAS': lambda args: args['b'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [4, 8, 16]
    ],
    key=['D', 'IS_RMS_NORM', 'STORE_RESIDUAL_OUT', 'HAS_RESIDUAL', 'HAS_WEIGHT'],  # NB removed
    **autotune_cache_kwargs,
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    x, g, y, w, b, residual, residual_out, mean, rstd, eps,
    T,
    D: tl.constexpr, BT: tl.constexpr, BD: tl.constexpr,
    NB,  # removed constexpr
    ACTIVATION: tl.constexpr, IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr, HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    o_d = tl.arange(0, BD)
    m_d = o_d < D

    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        p_res_out = tl.make_block_ptr(residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None]
    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b[None, :]

    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
    if ACTIVATION == 'swish' or ACTIVATION == 'silu':
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == 'sigmoid':
        b_y = b_y * tl.sigmoid(b_g)

    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))

layer_norm_gated_bwd_kernel ¤

layer_norm_gated_bwd_kernel(
    x,
    g,
    w,
    b,
    y,
    dy,
    dx,
    dg,
    dw,
    db,
    dresidual,
    dresidual_in,
    mean,
    rstd,
    T,
    BS,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NB,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
)
Source code in src/microfuel/hacks.py
@triton.heuristics({
    'HAS_DRESIDUAL': lambda args: args['dresidual'] is not None,
    'HAS_WEIGHT': lambda args: args['w'] is not None,
    'HAS_BIAS': lambda args: args['b'] is not None,
    'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [4, 8, 16]
    ],
    key=['D', 'IS_RMS_NORM', 'HAS_DRESIDUAL', 'HAS_WEIGHT'],  # NB removed
    **autotune_cache_kwargs,
)
@triton.jit
def layer_norm_gated_bwd_kernel(
    x, g, w, b, y, dy, dx, dg, dw, db, dresidual, dresidual_in, mean, rstd,
    T, BS,
    D: tl.constexpr, BT: tl.constexpr, BD: tl.constexpr,
    NB,  # removed constexpr
    ACTIVATION: tl.constexpr, IS_RMS_NORM: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr, HAS_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr, HAS_BIAS: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr,
):
    i_s = tl.program_id(0)
    o_d = tl.arange(0, BD)
    m_d = o_d < D
    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
        b_dw = tl.zeros((BT, BD), dtype=tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d, other=0.0).to(tl.float32)
        b_db = tl.zeros((BT, BD), dtype=tl.float32)

    T = min(i_s * BS + BS, T)
    for i_t in range(i_s * BS, T, BT):
        p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_dy = tl.make_block_ptr(dy, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
        b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)

        if not IS_RMS_NORM:
            p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t,), (BT,), (0,))
            b_mean = tl.load(p_mean, boundary_check=(0,))
        p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t,), (BT,), (0,))
        b_rstd = tl.load(p_rstd, boundary_check=(0,))
        b_xhat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None]
        b_xhat = tl.where(m_d[None, :], b_xhat, 0.0)

        b_y = b_xhat * b_w[None, :] if HAS_WEIGHT else b_xhat
        if HAS_BIAS:
            b_y = b_y + b_b[None, :]
        if RECOMPUTE_OUTPUT:
            p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
            tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))

        b_sigmoid_g = tl.sigmoid(b_g)
        if ACTIVATION == 'swish' or ACTIVATION == 'silu':
            b_dg = b_dy * b_y * (b_sigmoid_g + b_g * b_sigmoid_g * (1 - b_sigmoid_g))
            b_dy = b_dy * b_g * b_sigmoid_g
        elif ACTIVATION == 'sigmoid':
            b_dg = b_dy * b_y * b_sigmoid_g * (1 - b_sigmoid_g)
            b_dy = b_dy * b_sigmoid_g
        b_wdy = b_dy

        if HAS_WEIGHT or HAS_BIAS:
            m_t = (i_t + tl.arange(0, BT)) < T
        if HAS_WEIGHT:
            b_wdy = b_dy * b_w
            b_dw += tl.where(m_t[:, None], b_dy * b_xhat, 0.0)
        if HAS_BIAS:
            b_db += tl.where(m_t[:, None], b_dy, 0.0)
        if not IS_RMS_NORM:
            b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D
            b_c2 = tl.sum(b_wdy, axis=1) / D
            b_dx = (b_wdy - (b_xhat * b_c1[:, None] + b_c2[:, None])) * b_rstd[:, None]
        else:
            b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D
            b_dx = (b_wdy - b_xhat * b_c1[:, None]) * b_rstd[:, None]
        if HAS_DRESIDUAL:
            p_dres = tl.make_block_ptr(dresidual, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
            b_dres = tl.load(p_dres, boundary_check=(0, 1)).to(tl.float32)
            b_dx += b_dres
        if STORE_DRESIDUAL:
            p_dres_in = tl.make_block_ptr(dresidual_in, (T, D), (D, 1), (i_t, 0), (BT, BD), (1, 0))
            tl.store(p_dres_in, b_dx.to(p_dres_in.dtype.element_ty), boundary_check=(0, 1))

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))

    if HAS_WEIGHT:
        tl.store(dw + i_s * D + o_d, tl.sum(b_dw, axis=0), mask=m_d)
    if HAS_BIAS:
        tl.store(db + i_s * D + o_d, tl.sum(b_db, axis=0), mask=m_d)

layer_norm_gated_fwd ¤

layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    eps: float = 1e-05,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
)
Source code in src/microfuel/hacks.py
def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    T, D = x.shape
    if residual is not None:
        assert residual.shape == (T, D)
    if weight is not None:
        assert weight.shape == (D,)
    if bias is not None:
        assert bias.shape == (D,)
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):
        residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
    else:
        residual_out = None
    mean = torch.empty((T,), dtype=torch.float, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((T,), dtype=torch.float, device=x.device)
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    if D <= 512:
        NB = triton.cdiv(T, 2048)
        def grid(meta): return (triton.cdiv(T, meta['BT']),)
        layer_norm_gated_fwd_kernel[grid](
            x=x, g=g, y=y, w=weight, b=bias, residual=residual, residual_out=residual_out,
            mean=mean, rstd=rstd, eps=eps, T=T, D=D, BD=BD, NB=NB,
            ACTIVATION=activation, IS_RMS_NORM=is_rms_norm,
        )
    else:
        fused_norm_gate.layer_norm_gated_fwd_kernel1[(T,)](
            x=x, g=g, y=y, w=weight, b=bias, residual=residual, residual_out=residual_out,
            mean=mean, rstd=rstd, eps=eps, D=D, BD=BD,
            ACTIVATION=activation, IS_RMS_NORM=is_rms_norm,
        )
    return y, mean, rstd, residual_out if residual_out is not None else x
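
For reference, here is a minimal unfused sketch of what the fused kernels compute, derived from the backward's activation handling: normalise x (layer norm or RMS norm), apply the affine parameters, then gate with swish(g) = g * sigmoid(g). layer_norm_gated_ref is a hypothetical helper for sanity checks (it ignores the residual path), not part of this module:

import torch

def layer_norm_gated_ref(x, g, weight=None, bias=None, eps=1e-5, is_rms_norm=False):
    # unfused equivalent of layer_norm_gated_fwd, computed in float32
    x32, g32 = x.float(), g.float()
    if is_rms_norm:
        xhat = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    else:
        mean = x32.mean(-1, keepdim=True)
        var = ((x32 - mean) ** 2).mean(-1, keepdim=True)
        xhat = (x32 - mean) * torch.rsqrt(var + eps)
    y = xhat if weight is None else xhat * weight.float()
    if bias is not None:
        y = y + bias.float()
    return (y * g32 * torch.sigmoid(g32)).to(x.dtype)  # swish gate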

layer_norm_gated_bwd ¤

layer_norm_gated_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    eps: float = 1e-05,
    mean: torch.Tensor = None,
    rstd: torch.Tensor = None,
    dresidual: torch.Tensor = None,
    has_residual: bool = False,
    is_rms_norm: bool = False,
    x_dtype: torch.dtype = None,
    recompute_output: bool = False,
)
Source code in src/microfuel/hacks.py
def layer_norm_gated_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = 'swish',
    eps: float = 1e-5,
    mean: torch.Tensor = None,
    rstd: torch.Tensor = None,
    dresidual: torch.Tensor = None,
    has_residual: bool = False,
    is_rms_norm: bool = False,
    x_dtype: torch.dtype = None,
    recompute_output: bool = False,
):
    T, D = x.shape
    dx = torch.empty_like(x) if x_dtype is None else torch.empty(T, D, dtype=x_dtype, device=x.device)
    dg = torch.empty_like(g) if x_dtype is None else torch.empty(T, D, dtype=x_dtype, device=x.device)
    dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None
    y = torch.empty(T, D, dtype=dy.dtype, device=dy.device) if recompute_output else None

    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    NS = get_multiprocessor_count(x.device.index)
    BS = math.ceil(T / NS)

    dw = torch.empty((NS, D), dtype=torch.float, device=weight.device) if weight is not None else None
    db = torch.empty((NS, D), dtype=torch.float, device=bias.device) if bias is not None else None
    grid = (NS,)

    if D <= 512:
        NB = triton.cdiv(T, 2048)
        layer_norm_gated_bwd_kernel[grid](
            x=x, g=g, w=weight, b=bias, y=y, dy=dy, dx=dx, dg=dg, dw=dw, db=db,
            dresidual=dresidual, dresidual_in=dresidual_in, mean=mean, rstd=rstd,
            T=T, D=D, BS=BS, BD=BD, NB=NB,
            ACTIVATION=activation, IS_RMS_NORM=is_rms_norm,
            STORE_DRESIDUAL=dresidual_in is not None,
        )
    else:
        fused_norm_gate.layer_norm_gated_bwd_kernel1[grid](
            x=x, g=g, w=weight, b=bias, y=y, dy=dy, dx=dx, dg=dg, dw=dw, db=db,
            dresidual=dresidual, dresidual_in=dresidual_in, mean=mean, rstd=rstd,
            T=T, D=D, BS=BS, BD=BD,
            ACTIVATION=activation, IS_RMS_NORM=is_rms_norm,
            STORE_DRESIDUAL=dresidual_in is not None,
        )
    dw = dw.sum(0).to(weight.dtype) if weight is not None else None
    db = db.sum(0).to(bias.dtype) if bias is not None else None
    if has_residual and dx.dtype == x.dtype:
        dresidual_in = dx
    return (dx, dg, dw, db, dresidual_in) if not recompute_output else (dx, dg, dw, db, dresidual_in, y)
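
A hedged sanity check of the fused backward against autograd on the unfused sketch above (assumes a CUDA device, the layer_norm_gated_ref helper defined earlier, and illustrative tolerances):

T, D = 128, 256
x = torch.randn(T, D, device='cuda', requires_grad=True)
g = torch.randn(T, D, device='cuda', requires_grad=True)
w = torch.randn(D, device='cuda', requires_grad=True)
b = torch.randn(D, device='cuda', requires_grad=True)
dy = torch.randn(T, D, device='cuda')

layer_norm_gated_ref(x, g, w, b).backward(dy)  # reference gradients

y, mean, rstd, _ = layer_norm_gated_fwd(x.detach(), g.detach(), w.detach(), b.detach())
dx, dg, dw, db, _ = layer_norm_gated_bwd(
    dy, x.detach(), g.detach(), w.detach(), b.detach(), mean=mean, rstd=rstd
)
torch.testing.assert_close(dx, x.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dw, w.grad, rtol=1e-2, atol=1e-2)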

install_optimized_kernels_ ¤

install_optimized_kernels_()

Patches FLA modules with optimized versions to reduce JIT recompilation.

Source code in src/microfuel/hacks.py
def install_optimized_kernels_():
    """Patches FLA modules with optimized versions to reduce JIT recompilation."""
    convolution.causal_conv1d_fwd_kernel = causal_conv1d_fwd_kernel
    convolution.causal_conv1d_bwd_kernel = causal_conv1d_bwd_kernel
    convolution.causal_conv1d_fwd = causal_conv1d_fwd
    convolution.causal_conv1d_bwd = causal_conv1d_bwd

    l2norm.l2norm_fwd_kernel = l2norm_fwd_kernel
    l2norm.l2norm_bwd_kernel = l2norm_bwd_kernel
    l2norm.l2norm_fwd = l2norm_fwd
    l2norm.l2norm_bwd = l2norm_bwd

    fused_norm_gate.layer_norm_gated_fwd_kernel = layer_norm_gated_fwd_kernel
    fused_norm_gate.layer_norm_gated_bwd_kernel = layer_norm_gated_bwd_kernel
    fused_norm_gate.layer_norm_gated_fwd = layer_norm_gated_fwd
    fused_norm_gate.layer_norm_gated_bwd = layer_norm_gated_bwd
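
Intended usage is a one-time patch at startup, before any FLA-backed modules run, so subsequent GatedDeltaNet calls resolve to the patched kernels (a sketch, assuming nothing else re-imports the originals):

from microfuel.hacks import install_optimized_kernels_

install_optimized_kernels_()  # patch FLA once, at startup
# ... build and run models as usual; FLA now dispatches to the patched kernels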

model ¤

logger module-attribute ¤

logger = logging.getLogger(__name__)

ZeroCentredRMSNorm ¤

Bases: nn.Module

Avoids abnormal amplification of some weights in the original QK-norm. During regularisation and weight decay, weight will be pushed near 0.

See: https://ceramic.ai/blog/zerocentered

Source code in src/microfuel/model.py
class ZeroCentredRMSNorm(nn.Module):
    """Avoids abnormal amplification of some weights in the original QK-norm.
    During regularisation and weight decay, `weight` will be pushed near 0.

    See: https://ceramic.ai/blog/zerocentered"""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (hidden_states * (1.0 + self.weight)).to(input_dtype)
weight instance-attribute ¤
weight = nn.Parameter(torch.zeros(hidden_size))
variance_epsilon instance-attribute ¤
variance_epsilon = eps
__init__ ¤
__init__(hidden_size: int, eps: float = 1e-06)
Source code in src/microfuel/model.py
def __init__(self, hidden_size: int, eps: float = 1e-6):
    super().__init__()
    self.weight = nn.Parameter(torch.zeros(hidden_size))
    self.variance_epsilon = eps
forward ¤
forward(hidden_states: torch.Tensor) -> torch.Tensor
Source code in src/microfuel/model.py
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    return (hidden_states * (1.0 + self.weight)).to(input_dtype)
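
A small sketch of the equivalence: the module behaves like a standard RMSNorm whose scale is (1 + weight), so weight decay pulls the effective scale toward 1 rather than toward 0 (sizes are illustrative):

import torch

norm = ZeroCentredRMSNorm(hidden_size=8)
with torch.no_grad():
    norm.weight.copy_(torch.randn(8) * 0.1)  # non-trivial scale for the check
x = torch.randn(4, 8)

rms = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
expected = (rms * (1.0 + norm.weight)).to(x.dtype)
torch.testing.assert_close(norm(x), expected)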

Pooler ¤

Bases: nn.Module

Source code in src/microfuel/model.py
class Pooler(nn.Module):
    def __init__(self, mode: Literal["mean", "last"] = "last"):
        super().__init__()
        self.mode = mode

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        if self.mode == "last":
            return x[cu_seqlens[1:] - 1]
        elif self.mode == "mean":
            indices = torch.arange(x.size(0), device=x.device)
            return torch.nn.functional.embedding_bag(
                indices, x, offsets=cu_seqlens[:-1], mode="mean"
            )
        else:
            raise ValueError(f"unknown {self.mode=}")
mode instance-attribute ¤
mode = mode
__init__ ¤
__init__(mode: Literal['mean', 'last'] = 'last')
Source code in src/microfuel/model.py
def __init__(self, mode: Literal["mean", "last"] = "last"):
    super().__init__()
    self.mode = mode
forward ¤
forward(
    x: torch.Tensor, cu_seqlens: torch.Tensor
) -> torch.Tensor
Source code in src/microfuel/model.py
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    if self.mode == "last":
        return x[cu_seqlens[1:] - 1]
    elif self.mode == "mean":
        indices = torch.arange(x.size(0), device=x.device)
        return torch.nn.functional.embedding_bag(
            indices, x, offsets=cu_seqlens[:-1], mode="mean"
        )
    else:
        raise ValueError(f"unknown {self.mode=}")
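
Shape sketch for the packed layout (hypothetical sizes): three sequences of lengths 3, 2 and 4 packed into one (9, D) tensor, with boundaries given by cu_seqlens:

import torch

x = torch.randn(9, 4)                       # 3 packed sequences, D = 4
cu_seqlens = torch.tensor([0, 3, 5, 9])     # cumulative sequence lengths

last = Pooler(mode="last")(x, cu_seqlens)   # rows 2, 4, 8 -> shape (3, 4)
mean = Pooler(mode="mean")(x, cu_seqlens)   # per-sequence means -> shape (3, 4)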

LinearAttentionBlock ¤

Bases: nn.Module

Source code in src/microfuel/model.py
class LinearAttentionBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
        super().__init__()
        self.norm = ZeroCentredRMSNorm(hidden_size)
        self.gdn = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, head_dim=head_dim)

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.norm(x)
        x, _, _ = self.gdn(x.unsqueeze(0), cu_seqlens=cu_seqlens)
        x = x.squeeze(0)
        x = x + residual
        return x
norm instance-attribute ¤
norm = ZeroCentredRMSNorm(hidden_size)
gdn instance-attribute ¤
gdn = GatedDeltaNet(
    hidden_size=hidden_size,
    num_heads=num_heads,
    head_dim=head_dim,
)
__init__ ¤
__init__(hidden_size: int, num_heads: int, head_dim: int)
Source code in src/microfuel/model.py
def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
    super().__init__()
    self.norm = ZeroCentredRMSNorm(hidden_size)
    self.gdn = GatedDeltaNet(hidden_size=hidden_size, num_heads=num_heads, head_dim=head_dim)
forward ¤
forward(
    x: torch.Tensor, cu_seqlens: torch.Tensor
) -> torch.Tensor
Source code in src/microfuel/model.py
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    residual = x
    x = self.norm(x)
    x, _, _ = self.gdn(x.unsqueeze(0), cu_seqlens=cu_seqlens)
    x = x.squeeze(0)
    x = x + residual
    return x
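
Usage sketch (hypothetical sizes; the FLA GatedDeltaNet kernels need a CUDA device and int32 cu_seqlens). The block preserves the packed (total_tokens, hidden_size) layout:

block = LinearAttentionBlock(hidden_size=64, num_heads=4, head_dim=12).cuda()
x = torch.randn(9, 64, device="cuda")       # 3 packed sequences
cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32, device="cuda")
out = block(x, cu_seqlens)                  # same shape as x: (9, 64)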

StaticHyperNet ¤

Bases: nn.Module

Creates a specialised feature extractor for each aircraft type, improving over feature conditioning (concatenating embeddings to input).

See: https://arxiv.org/pdf/1609.09106#page=3 (Section 3.1).

Source code in src/microfuel/model.py
class StaticHyperNet(nn.Module):
    """Creates a specialised feature extractor for each aircraft type, improving over
    feature conditioning (concatenating embeddings to input).

    See: https://arxiv.org/pdf/1609.09106#page=3 (Section 3.1)."""

    def __init__(
        self, num_aircraft_types: int, embedding_dim: int, input_dim: int, output_dim: int
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.embedding = nn.Embedding(num_aircraft_types, embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.GELU(),
            nn.Linear(64, (input_dim * output_dim) + output_dim),
        )

    def forward(
        self, aircraft_type_idx: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        embeddings = self.embedding(aircraft_type_idx)
        params = self.mlp(embeddings)
        weights_flat = params[:, : self.input_dim * self.output_dim]
        bias = params[:, self.input_dim * self.output_dim :]
        weights = weights_flat.view(-1, self.output_dim, self.input_dim)
        return weights, bias, embeddings
input_dim instance-attribute ¤
input_dim = input_dim
output_dim instance-attribute ¤
output_dim = output_dim
embedding instance-attribute ¤
embedding = nn.Embedding(num_aircraft_types, embedding_dim)
mlp instance-attribute ¤
mlp = nn.Sequential(
    nn.Linear(embedding_dim, 64),
    nn.GELU(),
    nn.Linear(64, input_dim * output_dim + output_dim),
)
__init__ ¤
__init__(
    num_aircraft_types: int,
    embedding_dim: int,
    input_dim: int,
    output_dim: int,
)
Source code in src/microfuel/model.py
def __init__(
    self, num_aircraft_types: int, embedding_dim: int, input_dim: int, output_dim: int
):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embedding = nn.Embedding(num_aircraft_types, embedding_dim)
    self.mlp = nn.Sequential(
        nn.Linear(embedding_dim, 64),
        nn.GELU(),
        nn.Linear(64, (input_dim * output_dim) + output_dim),
    )
forward ¤
forward(
    aircraft_type_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Source code in src/microfuel/model.py
def forward(
    self, aircraft_type_idx: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    embeddings = self.embedding(aircraft_type_idx)
    params = self.mlp(embeddings)
    weights_flat = params[:, : self.input_dim * self.output_dim]
    bias = params[:, self.input_dim * self.output_dim :]
    weights = weights_flat.view(-1, self.output_dim, self.input_dim)
    return weights, bias, embeddings
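
Sketch of how the generated parameters are applied per token, mirroring FuelBurnPredictor.forward (sizes are hypothetical; 26 matches len(AIRCRAFT_TYPES)):

import torch

hyper = StaticHyperNet(num_aircraft_types=26, embedding_dim=16, input_dim=10, output_dim=64)
aircraft_type_idx = torch.tensor([3, 7])        # batch of 2 sequences
lengths = torch.tensor([5, 8])                  # tokens per sequence
x = torch.randn(13, 10)                         # packed (sum(lengths), input_dim)

weights, bias, emb = hyper(aircraft_type_idx)   # (2, 64, 10), (2, 64), (2, 16)
w_tok = torch.repeat_interleave(weights, lengths, dim=0)   # (13, 64, 10)
b_tok = torch.repeat_interleave(bias, lengths, dim=0)      # (13, 64)
h = torch.bmm(w_tok, x.unsqueeze(-1)).squeeze(-1) + b_tok  # (13, 64)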

FuelBurnPredictorConfig dataclass ¤

Source code in src/microfuel/model.py
@dataclass
class FuelBurnPredictorConfig:
    input_dim: int
    hidden_size: int
    num_heads: int
    num_aircraft_types: int
    aircraft_embedding_dim: int
    num_layers: int
    pooler_mode: Literal["mean", "last"]  # TODO: get rid of this, only support last
input_dim instance-attribute ¤
input_dim: int
hidden_size instance-attribute ¤
hidden_size: int
num_heads instance-attribute ¤
num_heads: int
num_aircraft_types instance-attribute ¤
num_aircraft_types: int
aircraft_embedding_dim instance-attribute ¤
aircraft_embedding_dim: int
num_layers instance-attribute ¤
num_layers: int
pooler_mode instance-attribute ¤
pooler_mode: Literal['mean', 'last']
__init__ ¤
__init__(
    input_dim: int,
    hidden_size: int,
    num_heads: int,
    num_aircraft_types: int,
    aircraft_embedding_dim: int,
    num_layers: int,
    pooler_mode: Literal["mean", "last"],
) -> None

FuelBurnPredictor ¤

Bases: nn.Module

Gated Delta Network for fuel burn estimation.

It processes data at two resolutions to solve the mass identifiability problem:

  1. Processes the high-fidelity kinematics $x_{t:t+\Delta}$ for the specific query interval.
  2. Processes the entire trajectory $x_{0:T}$ (takeoff to landing).

Hypothesis: The pooled_flight vector acts as a compressed context containing implicit estimates of the aircraft's takeoff mass and degradation factors, which are globally observable over the full flight but locally unobservable.

NOTE: Instead of padding, sequences are tightly packed together in a long tensor, and FLA is informed of boundaries via the cu_seqlens tensor.

Source code in src/microfuel/model.py
class FuelBurnPredictor(nn.Module):
    r"""Gated Delta Network for fuel burn estimation.

    It processes data at two resolutions to solve the mass identifiability problem:

    1. Processes the high-fidelity kinematics $x_{t:t+\Delta}$ for the specific query interval.
    2. Processes the entire trajectory $x_{0:T}$ (takeoff to landing).

    Hypothesis:
        The `pooled_flight` vector acts as a compressed context containing
        implicit estimates of the aircraft's takeoff mass and degradation factors,
        which are globally observable over the full flight but locally unobservable.

    NOTE: Instead of padding, sequences are tightly packed together in a long tensor, and
    FLA is informed of boundaries via the `cu_seqlens` tensor.
    """

    def __init__(self, cfg: FuelBurnPredictorConfig):
        super().__init__()
        self.config = cfg
        key_dim = int(cfg.hidden_size * 0.75)  # 0.75 ratio as per the GatedDeltaNet docstring
        assert key_dim % cfg.num_heads == 0, (
            "int(hidden_size * 0.75) must be divisible by num_heads (use_gate=True)"
        )
        head_dim = key_dim // cfg.num_heads

        # segment processing branch
        self.hypernetwork_segment = StaticHyperNet(
            num_aircraft_types=cfg.num_aircraft_types,
            embedding_dim=cfg.aircraft_embedding_dim,
            input_dim=cfg.input_dim,
            output_dim=cfg.hidden_size,
        )
        self.layers_segment = nn.ModuleList(
            [
                LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
                for _ in range(cfg.num_layers)
            ]
        )
        self.pooler_segment = Pooler(mode=cfg.pooler_mode)

        # flight context processing branch
        self.hypernetwork_flight = StaticHyperNet(
            num_aircraft_types=cfg.num_aircraft_types,
            embedding_dim=cfg.aircraft_embedding_dim,
            input_dim=cfg.input_dim,
            output_dim=cfg.hidden_size,
        )
        self.layers_flight = nn.ModuleList(
            [
                LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
                for _ in range(cfg.num_layers)
            ]
        )
        self.pooler_flight = Pooler(mode=cfg.pooler_mode)

        # share embedding layer between hypernetworks
        self.hypernetwork_flight.embedding = self.hypernetwork_segment.embedding

        self.regression_head = nn.Linear(
            cfg.hidden_size + cfg.hidden_size + cfg.aircraft_embedding_dim, 1
        )

    def forward(
        self,
        x_flight: torch.Tensor,
        cu_seqlens_flight: torch.Tensor,
        x_segment: torch.Tensor,
        cu_seqlens_segment: torch.Tensor,
        aircraft_type_idx: torch.Tensor,
    ) -> torch.Tensor:
        """:param x_flight: packed tensor of full flight trajectories
        :param cu_seqlens_flight: cumulative sequence lengths for flight tensor
        :param x_segment: packed tensor of trajectory segments for prediction
        :param cu_seqlens_segment: cumulative sequence lengths for segment tensor
        :param aircraft_type_idx: (B,) tensor of aircraft type indices"""
        # segment processing
        segment_lengths = cu_seqlens_segment[1:] - cu_seqlens_segment[:-1]
        weights_s, bias_s, ac_embeddings = self.hypernetwork_segment(aircraft_type_idx)
        weights_expanded_s = torch.repeat_interleave(weights_s, segment_lengths, dim=0)
        bias_expanded_s = torch.repeat_interleave(bias_s, segment_lengths, dim=0)
        x_s = torch.bmm(weights_expanded_s, x_segment.unsqueeze(-1)).squeeze(-1) + bias_expanded_s

        for layer in self.layers_segment:
            x_s = layer(x_s, cu_seqlens_segment)
        pooled_segment = self.pooler_segment(x_s, cu_seqlens_segment)

        # flight context processing
        flight_lengths = cu_seqlens_flight[1:] - cu_seqlens_flight[:-1]
        weights_f, bias_f, _ = self.hypernetwork_flight(aircraft_type_idx)
        weights_expanded_f = torch.repeat_interleave(weights_f, flight_lengths, dim=0)
        bias_expanded_f = torch.repeat_interleave(bias_f, flight_lengths, dim=0)
        x_f = torch.bmm(weights_expanded_f, x_flight.unsqueeze(-1)).squeeze(-1) + bias_expanded_f

        for layer in self.layers_flight:
            x_f = layer(x_f, cu_seqlens_flight)
        pooled_flight = self.pooler_flight(x_f, cu_seqlens_flight)

        # final regression
        combined_features = torch.cat([pooled_segment, pooled_flight, ac_embeddings], dim=1)
        y_pred = self.regression_head(combined_features)
        return y_pred
config instance-attribute ¤
config = cfg
hypernetwork_segment instance-attribute ¤
hypernetwork_segment = StaticHyperNet(
    num_aircraft_types=cfg.num_aircraft_types,
    embedding_dim=cfg.aircraft_embedding_dim,
    input_dim=cfg.input_dim,
    output_dim=cfg.hidden_size,
)
layers_segment instance-attribute ¤
layers_segment = nn.ModuleList(
    [
        LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
        for _ in range(cfg.num_layers)
    ]
)
pooler_segment instance-attribute ¤
pooler_segment = Pooler(mode=cfg.pooler_mode)
hypernetwork_flight instance-attribute ¤
hypernetwork_flight = StaticHyperNet(
    num_aircraft_types=cfg.num_aircraft_types,
    embedding_dim=cfg.aircraft_embedding_dim,
    input_dim=cfg.input_dim,
    output_dim=cfg.hidden_size,
)
layers_flight instance-attribute ¤
layers_flight = nn.ModuleList(
    [
        LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
        for _ in range(cfg.num_layers)
    ]
)
pooler_flight instance-attribute ¤
pooler_flight = Pooler(mode=cfg.pooler_mode)
regression_head instance-attribute ¤
regression_head = nn.Linear(
    cfg.hidden_size
    + cfg.hidden_size
    + cfg.aircraft_embedding_dim,
    1,
)
__init__ ¤
__init__(cfg: FuelBurnPredictorConfig)
Source code in src/microfuel/model.py
def __init__(self, cfg: FuelBurnPredictorConfig):
    super().__init__()
    self.config = cfg
    key_dim = int(cfg.hidden_size * 0.75)  # 0.75 ratio as per the GatedDeltaNet docstring
    assert key_dim % cfg.num_heads == 0, (
        "int(hidden_size * 0.75) must be divisible by num_heads (use_gate=True)"
    )
    head_dim = key_dim // cfg.num_heads

    # segment processing branch
    self.hypernetwork_segment = StaticHyperNet(
        num_aircraft_types=cfg.num_aircraft_types,
        embedding_dim=cfg.aircraft_embedding_dim,
        input_dim=cfg.input_dim,
        output_dim=cfg.hidden_size,
    )
    self.layers_segment = nn.ModuleList(
        [
            LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
            for _ in range(cfg.num_layers)
        ]
    )
    self.pooler_segment = Pooler(mode=cfg.pooler_mode)

    # flight context processing branch
    self.hypernetwork_flight = StaticHyperNet(
        num_aircraft_types=cfg.num_aircraft_types,
        embedding_dim=cfg.aircraft_embedding_dim,
        input_dim=cfg.input_dim,
        output_dim=cfg.hidden_size,
    )
    self.layers_flight = nn.ModuleList(
        [
            LinearAttentionBlock(cfg.hidden_size, cfg.num_heads, head_dim)
            for _ in range(cfg.num_layers)
        ]
    )
    self.pooler_flight = Pooler(mode=cfg.pooler_mode)

    # share embedding layer between hypernetworks
    self.hypernetwork_flight.embedding = self.hypernetwork_segment.embedding

    self.regression_head = nn.Linear(
        cfg.hidden_size + cfg.hidden_size + cfg.aircraft_embedding_dim, 1
    )
forward ¤
forward(
    x_flight: torch.Tensor,
    cu_seqlens_flight: torch.Tensor,
    x_segment: torch.Tensor,
    cu_seqlens_segment: torch.Tensor,
    aircraft_type_idx: torch.Tensor,
) -> torch.Tensor

Parameters:

Name                 Type          Description                                           Default
x_flight             torch.Tensor  packed tensor of full flight trajectories             required
cu_seqlens_flight    torch.Tensor  cumulative sequence lengths for flight tensor         required
x_segment            torch.Tensor  packed tensor of trajectory segments for prediction   required
cu_seqlens_segment   torch.Tensor  cumulative sequence lengths for segment tensor        required
aircraft_type_idx    torch.Tensor  (B,) tensor of aircraft type indices                  required
Source code in src/microfuel/model.py
def forward(
    self,
    x_flight: torch.Tensor,
    cu_seqlens_flight: torch.Tensor,
    x_segment: torch.Tensor,
    cu_seqlens_segment: torch.Tensor,
    aircraft_type_idx: torch.Tensor,
) -> torch.Tensor:
    """:param x_flight: packed tensor of full flight trajectories
    :param cu_seqlens_flight: cumulative sequence lengths for flight tensor
    :param x_segment: packed tensor of trajectory segments for prediction
    :param cu_seqlens_segment: cumulative sequence lengths for segment tensor
    :param aircraft_type_idx: (B,) tensor of aircraft type indices"""
    # segment processing
    segment_lengths = cu_seqlens_segment[1:] - cu_seqlens_segment[:-1]
    weights_s, bias_s, ac_embeddings = self.hypernetwork_segment(aircraft_type_idx)
    weights_expanded_s = torch.repeat_interleave(weights_s, segment_lengths, dim=0)
    bias_expanded_s = torch.repeat_interleave(bias_s, segment_lengths, dim=0)
    x_s = torch.bmm(weights_expanded_s, x_segment.unsqueeze(-1)).squeeze(-1) + bias_expanded_s

    for layer in self.layers_segment:
        x_s = layer(x_s, cu_seqlens_segment)
    pooled_segment = self.pooler_segment(x_s, cu_seqlens_segment)

    # flight context processing
    flight_lengths = cu_seqlens_flight[1:] - cu_seqlens_flight[:-1]
    weights_f, bias_f, _ = self.hypernetwork_flight(aircraft_type_idx)
    weights_expanded_f = torch.repeat_interleave(weights_f, flight_lengths, dim=0)
    bias_expanded_f = torch.repeat_interleave(bias_f, flight_lengths, dim=0)
    x_f = torch.bmm(weights_expanded_f, x_flight.unsqueeze(-1)).squeeze(-1) + bias_expanded_f

    for layer in self.layers_flight:
        x_f = layer(x_f, cu_seqlens_flight)
    pooled_flight = self.pooler_flight(x_f, cu_seqlens_flight)

    # final regression
    combined_features = torch.cat([pooled_segment, pooled_flight, ac_embeddings], dim=1)
    y_pred = self.regression_head(combined_features)
    return y_pred
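
A hedged end-to-end sketch on random packed data (dimensions are illustrative, not the trained configuration; the FLA kernels require a CUDA device):

cfg = FuelBurnPredictorConfig(
    input_dim=10, hidden_size=64, num_heads=4,
    num_aircraft_types=26, aircraft_embedding_dim=16,
    num_layers=2, pooler_mode="last",
)
model = FuelBurnPredictor(cfg).cuda()

x_flight = torch.randn(90, cfg.input_dim, device="cuda")    # 2 flights: 40 + 50 points
cu_flight = torch.tensor([0, 40, 90], dtype=torch.int32, device="cuda")
x_segment = torch.randn(20, cfg.input_dim, device="cuda")   # 2 segments: 8 + 12 points
cu_segment = torch.tensor([0, 8, 20], dtype=torch.int32, device="cuda")
ac_idx = torch.tensor([3, 7], device="cuda")

y_pred = model(x_flight, cu_flight, x_segment, cu_segment, ac_idx)  # shape (2, 1)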

plot ¤

default_fig ¤

default_fig(*args, **kwargs) -> Figure
Source code in src/microfuel/plot.py
def default_fig(*args, **kwargs) -> Figure:
    _init_style(dark=False)
    if "figsize" not in kwargs:
        kwargs["figsize"] = (12, 7)
    fig = plt.figure(*args, **kwargs)
    fig.set_layout_engine("tight")
    return fig
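
Usage sketch (PATH_PLOTS_OUTPUT comes from the package root; the file name is hypothetical):

from microfuel import PATH_PLOTS_OUTPUT

fig = default_fig()
ax = fig.add_subplot()
ax.plot(range(10))
fig.savefig(PATH_PLOTS_OUTPUT / "example.png")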