Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
6b12f87
feat(core): add MaskExpression POJO and projection support for ReadRel
flex-seongmin Mar 26, 2026
8bd8942
test(core): add MaskExpression roundtrip tests for nested list, map, …
flex-seongmin Mar 26, 2026
b723737
chore(core): format code for spotlessJavaApply lint
flex-seongmin Mar 26, 2026
ab49ad6
fix(core): apply MaskExpression projection to ReadRel deriveRecordType
flex-seongmin Mar 30, 2026
2888f2d
chore(core): apply lint
flex-seongmin Apr 2, 2026
821a424
refactor(core): PR: replace MaskExpression Select union with visitor …
flex-seongmin Apr 2, 2026
ecfd05f
refactor(core): PR: split MaskExpressionProtoConverter into toProto a…
flex-seongmin Apr 2, 2026
2330204
refactor(core): PR: merge MaskExpr into MaskExpression interface
flex-seongmin Apr 2, 2026
deb0469
docs(core): add javadoc to MaskExpression proto converter classes
flex-seongmin Apr 8, 2026
b010578
refactor(core): use MaskExpressionVisitor for Select dispatch and cle…
flex-seongmin Apr 8, 2026
a285952
docs(core): add javadoc to MaskExpression
flex-seongmin Apr 8, 2026
70aa512
docs(core): rewrite javadoc
flex-seongmin Apr 8, 2026
4ca34b7
chore(core): renaming argument of projectStruct function
flex-seongmin Apr 8, 2026
05a38f2
chroe(core): renaming argument of project function and add javadoc
flex-seongmin Apr 8, 2026
3adb196
refactor(core): remove MaskExpression.Mask by combining @Value.Immuta…
flex-seongmin Apr 8, 2026
dd9cac0
docs(core): add javadoc return to MaskExpression interface
flex-seongmin Apr 9, 2026
375704c
docs(core): lint javadoc
flex-seongmin Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
422 changes: 422 additions & 0 deletions core/src/main/java/io/substrait/expression/MaskExpression.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package io.substrait.expression;

import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import io.substrait.util.EmptyVisitationContext;
import java.util.List;

/**
 * Derives the pruned {@link Type.Struct} that results from applying a {@link MaskExpression}
 * projection to a struct type.
 *
 * <p>Struct selects keep only the referenced fields, in the order the mask lists them, and recurse
 * into any child select. List and map selects never change the container kind: they only prune
 * the element type (lists) or the value type (maps) when a child select is present. Which list
 * elements or map keys are selected has no influence on the derived type.
 */
public final class MaskExpressionTypeProjector {

  private MaskExpressionTypeProjector() {}

  /**
   * Applies the given projection to a struct type, returning a pruned struct.
   *
   * @param projection the mask expression projection
   * @param structType the struct type to project
   * @return a pruned struct containing only the selected fields, with the nullability of {@code
   *     structType}
   * @throws IndexOutOfBoundsException if a struct item references a field position that {@code
   *     structType} (or a nested struct) does not have
   * @throws ClassCastException if a struct item carries a child select whose kind (struct, list,
   *     map) does not match the type of the field it references
   */
  public static Type.Struct project(MaskExpression projection, Type.Struct structType) {
    return projectStruct(projection.getSelect(), structType);
  }

  /** Builds a struct holding one (possibly pruned) field per struct item, in mask order. */
  private static Type.Struct projectStruct(
      MaskExpression.StructSelect structSelect, Type.Struct structType) {
    List<Type> fields = structType.fields();
    List<MaskExpression.StructItem> items = structSelect.getStructItems();

    return TypeCreator.of(structType.nullable())
        .struct(items.stream().map(item -> projectItem(item, fields.get(item.getField()))));
  }

  /**
   * Projects a single selected field. Without a child select the field type is kept as is;
   * otherwise the child select must match the kind of the field type (enforced by a cast).
   */
  private static Type projectItem(MaskExpression.StructItem item, Type fieldType) {
    if (!item.getChild().isPresent()) {
      return fieldType;
    }

    MaskExpression.Select select = item.getChild().get();

    return select.accept(
        new MaskExpressionVisitor<Type, EmptyVisitationContext, RuntimeException>() {
          @Override
          public Type visit(
              MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
            return projectStruct(structSelect, (Type.Struct) fieldType);
          }

          @Override
          public Type visit(MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
            return projectList(listSelect, (Type.ListType) fieldType);
          }

          @Override
          public Type visit(MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
            return projectMap(mapSelect, (Type.Map) fieldType);
          }
        },
        EmptyVisitationContext.INSTANCE);
  }

  /**
   * Prunes the element type of a list according to the list select's child, if any. The list keeps
   * its own nullability.
   */
  private static Type.ListType projectList(
      MaskExpression.ListSelect listSelect, Type.ListType listType) {
    if (!listSelect.getChild().isPresent()) {
      return listType;
    }

    Type elementType = listType.elementType();
    Type prunedElement = projectNested(listSelect.getChild().get(), elementType);

    // Same instance back means the child select did not apply; keep the original list untouched.
    if (prunedElement == elementType) {
      return listType;
    }
    return TypeCreator.of(listType.nullable()).list(prunedElement);
  }

  /**
   * Prunes the value type of a map according to the map select's child, if any. The key type and
   * the map's own nullability are preserved.
   */
  private static Type.Map projectMap(MaskExpression.MapSelect mapSelect, Type.Map mapType) {
    if (!mapSelect.getChild().isPresent()) {
      return mapType;
    }

    Type valueType = mapType.value();
    Type prunedValue = projectNested(mapSelect.getChild().get(), valueType);

    // Same instance back means the child select did not apply; keep the original map untouched.
    if (prunedValue == valueType) {
      return mapType;
    }
    return TypeCreator.of(mapType.nullable()).map(mapType.key(), prunedValue);
  }

  /**
   * Applies a child select found inside a list or map to the contained type.
   *
   * <p>Every select kind is followed, so pruning reaches through arbitrarily nested containers
   * (for example a list of lists of structs). When the select kind does not match the contained
   * type, the type is returned unchanged (same instance) rather than failing.
   */
  private static Type projectNested(MaskExpression.Select select, Type type) {
    return select.accept(
        new MaskExpressionVisitor<Type, EmptyVisitationContext, RuntimeException>() {
          @Override
          public Type visit(
              MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
            if (type instanceof Type.Struct) {
              return projectStruct(structSelect, (Type.Struct) type);
            }
            return type;
          }

          @Override
          public Type visit(MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
            if (type instanceof Type.ListType) {
              return projectList(listSelect, (Type.ListType) type);
            }
            return type;
          }

          @Override
          public Type visit(MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
            if (type instanceof Type.Map) {
              return projectMap(mapSelect, (Type.Map) type);
            }
            return type;
          }
        },
        EmptyVisitationContext.INSTANCE);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package io.substrait.expression;

import io.substrait.util.VisitationContext;

/**
 * Double-dispatch visitor over the three kinds of {@link MaskExpression} select node: struct, list
 * and map.
 *
 * @param <R> the value produced by each {@code visit} call
 * @param <C> the type of context object handed to each {@code visit} call
 * @param <E> the kind of throwable a {@code visit} call is allowed to raise
 */
public interface MaskExpressionVisitor<R, C extends VisitationContext, E extends Throwable> {

  /**
   * Handles a select node that picks fields out of a struct.
   *
   * @param structSelect the struct select being visited
   * @param context the context passed along with this visit
   * @return the result computed for this node
   * @throws E if the visit cannot be completed
   */
  R visit(MaskExpression.StructSelect structSelect, C context) throws E;

  /**
   * Handles a select node that applies to a list.
   *
   * @param listSelect the list select being visited
   * @param context the context passed along with this visit
   * @return the result computed for this node
   * @throws E if the visit cannot be completed
   */
  R visit(MaskExpression.ListSelect listSelect, C context) throws E;

  /**
   * Handles a select node that applies to a map.
   *
   * @param mapSelect the map select being visited
   * @param context the context passed along with this visit
   * @return the result computed for this node
   * @throws E if the visit cannot be completed
   */
  R visit(MaskExpression.MapSelect mapSelect, C context) throws E;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package io.substrait.expression.proto;

import io.substrait.expression.MaskExpression;
import io.substrait.expression.MaskExpression.ListSelect;
import io.substrait.expression.MaskExpression.ListSelectItem;
import io.substrait.expression.MaskExpression.MapSelect;
import io.substrait.expression.MaskExpression.Select;
import io.substrait.expression.MaskExpression.StructItem;
import io.substrait.expression.MaskExpression.StructSelect;
import io.substrait.expression.MaskExpressionVisitor;
import io.substrait.proto.Expression;
import io.substrait.util.EmptyVisitationContext;

/**
* Converts from {@link io.substrait.expression.MaskExpression} to {@link Expression.MaskExpression}
*/
public final class MaskExpressionProtoConverter {

private MaskExpressionProtoConverter() {}

private static final MaskExpressionVisitor<
Expression.MaskExpression.Select, EmptyVisitationContext, RuntimeException>
SELECT_TO_PROTO_VISITOR =
new MaskExpressionVisitor<
Expression.MaskExpression.Select, EmptyVisitationContext, RuntimeException>() {
@Override
public Expression.MaskExpression.Select visit(
MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
return Expression.MaskExpression.Select.newBuilder()
.setStruct(toProto(structSelect))
.build();
}

@Override
public Expression.MaskExpression.Select visit(
MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
return Expression.MaskExpression.Select.newBuilder()
.setList(toProtoListSelect(listSelect))
.build();
}

@Override
public Expression.MaskExpression.Select visit(
MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
return Expression.MaskExpression.Select.newBuilder()
.setMap(toProtoMapSelect(mapSelect))
.build();
}
};

/**
* Converts a POJO {@link MaskExpression} to its proto representation.
*
* @param mask the POJO {@link MaskExpression}
* @return the proto {@link Expression.MaskExpression}
*/
public static Expression.MaskExpression toProto(MaskExpression mask) {
return Expression.MaskExpression.newBuilder()
.setSelect(toProto(mask.getSelect()))
.setMaintainSingularStruct(mask.getMaintainSingularStruct())
.build();
}

private static Expression.MaskExpression.StructSelect toProto(StructSelect structSelect) {
Expression.MaskExpression.StructSelect.Builder builder =
Expression.MaskExpression.StructSelect.newBuilder();
for (StructItem item : structSelect.getStructItems()) {
builder.addStructItems(toProto(item));
}
return builder.build();
}

private static Expression.MaskExpression.StructItem toProto(StructItem structItem) {
Expression.MaskExpression.StructItem.Builder builder =
Expression.MaskExpression.StructItem.newBuilder().setField(structItem.getField());
structItem.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
return builder.build();
}

private static Expression.MaskExpression.Select toProtoSelect(Select select) {
return select.accept(SELECT_TO_PROTO_VISITOR, EmptyVisitationContext.INSTANCE);
}

private static Expression.MaskExpression.ListSelect toProtoListSelect(ListSelect listSelect) {
Expression.MaskExpression.ListSelect.Builder builder =
Expression.MaskExpression.ListSelect.newBuilder();
for (ListSelectItem item : listSelect.getSelection()) {
builder.addSelection(toProtoListSelectItem(item));
}
listSelect.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
return builder.build();
}

private static Expression.MaskExpression.ListSelect.ListSelectItem toProtoListSelectItem(
ListSelectItem item) {
Expression.MaskExpression.ListSelect.ListSelectItem.Builder builder =
Expression.MaskExpression.ListSelect.ListSelectItem.newBuilder();
if (item.getItem().isPresent()) {
builder.setItem(
Expression.MaskExpression.ListSelect.ListSelectItem.ListElement.newBuilder()
.setField(item.getItem().get().getField())
.build());
} else if (item.getSlice().isPresent()) {
builder.setSlice(
Expression.MaskExpression.ListSelect.ListSelectItem.ListSlice.newBuilder()
.setStart(item.getSlice().get().getStart())
.setEnd(item.getSlice().get().getEnd())
.build());
} else {
throw new IllegalArgumentException("ListSelectItem must have either item or slice set");
}
return builder.build();
}

private static Expression.MaskExpression.MapSelect toProtoMapSelect(MapSelect mapSelect) {
Expression.MaskExpression.MapSelect.Builder builder =
Expression.MaskExpression.MapSelect.newBuilder();
mapSelect
.getKey()
.ifPresent(
key ->
builder.setKey(
Expression.MaskExpression.MapSelect.MapKey.newBuilder()
.setMapKey(key.getMapKey())
.build()));
mapSelect
.getExpression()
.ifPresent(
expr ->
builder.setExpression(
Expression.MaskExpression.MapSelect.MapKeyExpression.newBuilder()
.setMapKeyExpression(expr.getMapKeyExpression())
.build()));
mapSelect.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
return builder.build();
}
}
Loading
Loading