rhl
07/20/2022, 7:32 PM
Daniel Gafni
# 07/20/2022, 8:53 PM
class NDaysPartitionMapping(PartitionMapping):
    """Map each daily downstream partition to a trailing window of upstream days.

    For a downstream partition dated ``d``, the upstream partitions are the days
    ``d - offset - days`` through ``d - offset`` (both inclusive). ``days=0``
    therefore selects exactly one upstream day, ``offset`` days before ``d``.

    Partition keys are parsed and formatted with ``fmt`` (``"%Y-%m-%d"`` by
    default).

    NOTE(review): the computed upstream window is not clamped to the first
    partition of the upstream partitions definition, so the earliest downstream
    partitions may map to upstream keys that do not exist — confirm how callers
    handle that.
    """

    def __init__(self, days: int, offset: int = 0, fmt: str = "%Y-%m-%d"):
        """
        Args:
            days: Number of additional upstream days to include before the
                anchor day ``d - offset``; must be non-negative.
            offset: How many days before the downstream partition the window ends.
            fmt: ``strptime``/``strftime`` format of the daily partition keys.

        Raises:
            ValueError: If ``days`` is negative (the window's start would fall
                after its end).
        """
        if days < 0:
            raise ValueError(f"days must be non-negative, got {days}")
        self.days = days
        self.offset = offset
        self.fmt = fmt

    def _shift(self, key: str, delta_days: int) -> str:
        """Return partition ``key`` moved by ``delta_days`` days (may be negative)."""
        moved = datetime.strptime(key, self.fmt) + timedelta(days=delta_days)
        return moved.strftime(self.fmt)

    def get_downstream_partitions_for_partition_range(
        self,
        upstream_partition_key_range: PartitionKeyRange,
        downstream_partitions_def: PartitionsDefinition,  # pylint: disable=unused-argument
        upstream_partitions_def: PartitionsDefinition,
    ) -> PartitionKeyRange:
        """Return the downstream partitions that read any day in the upstream range.

        This is the inverse of ``get_upstream_partitions_for_partition_range``:
        upstream day ``k`` is read by downstream days ``k + offset`` through
        ``k + offset + days``. For a contiguous upstream range ``[s, e]`` those
        windows overlap, so the result is ``[s + offset, e + offset + days]``.
        """
        # Internal sanity check: this mapping only understands daily keys.
        assert isinstance(upstream_partitions_def, DailyPartitionsDefinition)
        return PartitionKeyRange(
            start=self._shift(upstream_partition_key_range.start, self.offset),
            end=self._shift(upstream_partition_key_range.end, self.offset + self.days),
        )

    def get_upstream_partitions_for_partition_range(
        self,
        downstream_partition_key_range: PartitionKeyRange,
        downstream_partitions_def: PartitionsDefinition,  # pylint: disable=unused-argument
        upstream_partitions_def: PartitionsDefinition,  # pylint: disable=unused-argument
    ) -> PartitionKeyRange:
        """Return the upstream days needed by the downstream range ``[s, e]``.

        The result is ``[s - offset - days, e - offset]``.
        """
        return PartitionKeyRange(
            start=self._shift(downstream_partition_key_range.start, -(self.offset + self.days)),
            end=self._shift(downstream_partition_key_range.end, -self.offset),
        )
An asset that uses it:
# Example asset consuming NDaysPartitionMapping. With days=0 and offset=2, each
# partition of this asset reads the single "ranking_model" partition dated two
# days earlier (see NDaysPartitionMapping.get_upstream_partitions_for_partition_range).
# The "columns" metadata attached to "candidates_with_features" is presumably
# consumed by the parquet IO manager's load_input as the column subset to read.
# NOTE(review): the function body was not included in the original snippet.
@asset(
partitions_def=raw_partitions,
group_name="catalog_ranking",
partition_mappings={"ranking_model": NDaysPartitionMapping(days=0, offset=2)},
io_manager_key="parquet_io_manager",
ins={"candidates_with_features": AssetIn(metadata={"columns": ALL_COLUMNS})},
)
def scored_candidates(ranking_model: CatBoost, candidates_with_features: pl.DataFrame) -> pl.DataFrame:
Support partition mappings inside your IO manager (this is the most important part):
def load_input(self, context: InputContext) -> Union[pl.DataFrame, List[pl.DataFrame]]:
    """Load a parquet-backed input as one DataFrame or as a list of daily partitions.

    The behavior depends on the type annotation of the downstream input:

    * ``pl.DataFrame`` -- load the single parquet file at ``self.get_path(context)``.
    * ``List[pl.DataFrame]`` -- load every day in ``context.asset_partition_key_range``
      (inclusive) from ``<dir>/<YYYY-MM-DD>.pq``, where ``<dir>`` is the directory
      of ``self.get_path(context)``.

    Input metadata keys read:
        columns: optional column subset forwarded to ``load_parquet_v2``.
        allow_missing_partitions: if truthy, partition files whose load raises
            ``FileNotFoundError`` are skipped instead of failing the load.

    Raises:
        TypeError: if a list of partitions is requested for an unpartitioned asset.
        FileNotFoundError: if a partition file is missing and
            ``allow_missing_partitions`` is not set.
    """
    from datetime import timedelta  # local import keeps this method self-contained

    input_type = context.dagster_type.typing_type
    path = self.get_path(context)
    columns = context.metadata.get("columns")
    if columns is not None:
        context.log.debug(f"{self.__class__} received metadata value columns={columns}")
    allow_missing_partitions = context.metadata.get("allow_missing_partitions", False)

    if input_type == pl.DataFrame:
        context.log.debug(f"Loading DataFrame from {path}")
        df = load_parquet_v2(path, columns=columns)
        context.add_input_metadata({"path": MetadataValue.path(path)})
        return df

    if input_type == List[pl.DataFrame]:
        # Load multiple partitions.
        if not context.has_asset_partitions:
            raise TypeError(f"Detected {input_type} input type but the asset is not partitioned")

        # NOTE(review): assumes get_path() returns "<base_dir>/<partition>.pq" for this
        # context; only its directory is used here -- confirm against get_path.
        base_dir = os.path.dirname(path)
        range_start, range_end = context.asset_partition_key_range
        first_day = datetime.strptime(range_start, "%Y-%m-%d").date()
        last_day = datetime.strptime(range_end, "%Y-%m-%d").date()
        # Inclusive daily keys; stdlib instead of pl.date_range, whose return type
        # (Series vs. lazy Expr) differs across Polars versions. Empty if inverted.
        partitions = [
            (first_day + timedelta(days=i)).strftime("%Y-%m-%d")
            for i in range((last_day - first_day).days + 1)
        ]
        context.log.debug(f"Loading {len(partitions)} partitions ({range_start} to {range_end})")

        dfs: List[pl.DataFrame] = []
        for partition in partitions:
            path_with_partition = os.path.join(base_dir, f"{partition}.pq")
            context.log.debug(f"Loading DataFrame partition from {path_with_partition}")
            try:
                dfs.append(load_parquet_v2(path_with_partition, columns=columns))
            except FileNotFoundError:
                if not allow_missing_partitions:
                    raise
                context.log.debug(f"Couldn't load partition {path_with_partition} and skipped it")
        return dfs

    return check.failed(
        f"Inputs of type {context.dagster_type} not supported. Please specify a valid type "
        "for this input either in the op signature or on the corresponding In."
    )
Sorry, the example is messy because I don't have time to clean it up for you right now, but I think you get the idea.
rhl
07/21/2022, 12:58 AM