Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public void selectExpressionWithoutFrom() {
givenQuery("SELECT 1 + 1")
.assertPlan(
"""
LogicalProject(1 + 1=[+(1, 1)])
LogicalProject(1 + 1=[+(1:BIGINT, 1:BIGINT)])
LogicalValues(tuples=[[{ 0 }]])
""");
}
Expand Down Expand Up @@ -404,7 +404,7 @@ public void testArithmeticOnAggregates() {
givenQuery("SELECT MAX(age) + MIN(age) AS range_sum FROM catalog.employees")
.assertPlan(
"""
LogicalProject(range_sum=[+($0, $1)])
LogicalProject(range_sum=[+(CAST($0):BIGINT, CAST($1):BIGINT)])
LogicalAggregate(group=[{}], MAX(age)=[MAX($0)], MIN(age)=[MIN($0)])
LogicalProject(age=[$2])
LogicalTableScan(table=[[catalog, employees]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProjec
final Function<RelNode, Function<RexNode, RexNode>> literalConverterProvider;
RexCall rexCall = (RexCall) project.getProjects().get(aggCall.getArgList().getFirst());
if (rexCall.getOperator().kind == SqlKind.PLUS
|| rexCall.getOperator().kind == SqlKind.MINUS) {
|| rexCall.getOperator().kind == SqlKind.MINUS
|| rexCall.getOperator().kind == SqlKind.CHECKED_PLUS
|| rexCall.getOperator().kind == SqlKind.CHECKED_MINUS) {
AggregateCall countCall =
AggregateCall.create(
aggCall.getParserPosition(),
Expand Down Expand Up @@ -280,7 +282,12 @@ private static boolean isCallWithLiteral(RexNode node) {

List<SqlKind> CONVERTABLE_FUNCTIONS =
List.of(
SqlKind.PLUS, SqlKind.MINUS, SqlKind.TIMES
SqlKind.PLUS,
SqlKind.MINUS,
SqlKind.TIMES,
SqlKind.CHECKED_PLUS,
SqlKind.CHECKED_MINUS,
SqlKind.CHECKED_TIMES
// Don't support division because of the issue of integer division
// e.g. (2000 / 3) * 3 = 1998 while 2000 * 3 / 3 = 2000
// SqlKind.DIVIDE
Expand Down
108 changes: 105 additions & 3 deletions core/src/main/java/org/opensearch/sql/executor/QueryService.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
Expand Down Expand Up @@ -174,7 +180,9 @@ public void executeWithCalcite(
RelNode calcitePlan =
StageErrorHandler.executeStage(
QueryProcessingStage.PLAN_CONVERSION,
() -> convertToCalcitePlan(relNode, context),
() ->
withCheckedArithmetic(
convertToCalcitePlan(relNode, context), context),
"while converting the query to an executable plan");

analyzeMetric.set(System.nanoTime() - analyzeStart);
Expand All @@ -187,7 +195,15 @@ public void executeWithCalcite(
},
QueryService.class);
} catch (Throwable t) {
if (isCalciteFallbackAllowed(t) && !(t instanceof NonFallbackCalciteException)) {
ArithmeticException overflow = findArithmeticOverflow(t);
if (overflow != null) {
// Checked arithmetic detected integer/long overflow. Surface as a client error
// instead of wrapping (silently) or falling back to the V2 engine.
propagateCalciteError(
new NonFallbackCalciteException(
"Arithmetic overflow: " + overflow.getMessage(), overflow),
listener);
} else if (isCalciteFallbackAllowed(t) && !(t instanceof NonFallbackCalciteException)) {
log.warn("Fallback to V2 query engine since got exception", t);
executeWithLegacy(plan, queryType, listener, Optional.of(t));
} else {
Expand Down Expand Up @@ -227,7 +243,8 @@ public void explainWithCalcite(
context.run(
() -> {
RelNode relNode = analyze(plan, context);
RelNode calcitePlan = convertToCalcitePlan(relNode, context);
RelNode calcitePlan =
withCheckedArithmetic(convertToCalcitePlan(relNode, context), context);
if (format != null) {
executionEngine.explain(calcitePlan, mode, format, context, listener);
} else {
Expand Down Expand Up @@ -383,6 +400,91 @@ private boolean isCalciteEnabled(Settings settings) {
}
}

/**
* Rewrite {@code +}/{@code -}/{@code *} to their overflow-checked variants ({@code CHECKED_PLUS}
* / {@code CHECKED_MINUS} / {@code CHECKED_MULTIPLY}) so integer and long arithmetic overflow
* throws {@link ArithmeticException} (via {@code Math.addExact} etc.) instead of silently
* wrapping. Applied before pushdown so both coordinator-executed and pushed-down (script)
* arithmetic are checked. Floating-point arithmetic is unchanged (IEEE 754).
*
* <p>This does the same rewrite as Calcite's {@code ConvertToChecked} but preserves each call's
* originally inferred type (via {@code makeCall(type, op, operands)}) and touches only the three
* arithmetic operators, so it does not re-derive the types of unrelated calls (e.g. {@code
* CEIL}/{@code DIVIDE}) the way {@code ConvertToChecked} does.
*/
private static RelNode withCheckedArithmetic(RelNode calcitePlan, CalcitePlanContext context) {
RexShuttle checkedShuttle =
new RexShuttle() {
@Override
public RexNode visitCall(RexCall call) {
RexNode visited = super.visitCall(call);
if (!(visited instanceof RexCall rexCall)) {
return visited;
}
SqlOperator checked =
switch (rexCall.getOperator().getKind()) {
case PLUS -> SqlStdOperatorTable.CHECKED_PLUS;
case MINUS -> SqlStdOperatorTable.CHECKED_MINUS;
case TIMES -> SqlStdOperatorTable.CHECKED_MULTIPLY;
default -> null;
};
// Only integer/long arithmetic can overflow silently and has a checked
// implementation (Math.addExact etc.). Float/double/decimal have no checked variant
// (SqlFunctions.checkedMultiply(double,double) does not exist) and follow IEEE 754, so
// leave them untouched.
if (checked == null || !isCheckableIntegerArithmetic(rexCall)) {
return visited;
}
return context.rexBuilder.makeCall(rexCall.getType(), checked, rexCall.getOperands());
}
};
return calcitePlan.accept(
new RelHomogeneousShuttle() {
@Override
public RelNode visit(RelNode other) {
RelNode visited = super.visitChildren(other);
return visited.accept(checkedShuttle);
}
});
}

/**
* Checked arithmetic is applied to BIGINT ({@code long}) operands only. Narrower integer
* arithmetic (byte/short/int) is already widened to a type that cannot overflow before this
* rewrite runs — {@code PPLFuncImpTable} promotes byte/short to INT and any int/long operand to
* BIGINT for {@code +}/{@code -}/{@code *} — so the sole remaining overflow case that reaches the
* Calcite engine is {@code long} arithmetic, which has no wider integer type to widen into.
* Float/double/decimal follow IEEE 754 (or decimal semantics) and have no {@code CHECKED_*}
* runtime (e.g. {@code SqlFunctions.checkedMultiply(double, double)} does not exist), so they are
* left untouched. Require both the result and every operand to be BIGINT.
*/
private static boolean isCheckableIntegerArithmetic(RexCall call) {
if (!isCheckableLongType(call.getType())) {
return false;
}
return call.getOperands().stream().allMatch(op -> isCheckableLongType(op.getType()));
}

private static boolean isCheckableLongType(org.apache.calcite.rel.type.RelDataType type) {
return type.getSqlTypeName() == org.apache.calcite.sql.type.SqlTypeName.BIGINT;
}

/**
* Walk the cause chain to find an {@link ArithmeticException} raised by checked arithmetic. Row-
* level overflow surfaces wrapped (SQLException -&gt; RuntimeException -&gt; ErrorReport), so a
* top-level {@code catch (ArithmeticException)} is insufficient.
*/
private static ArithmeticException findArithmeticOverflow(@Nullable Throwable t) {
for (Throwable cause = t;
cause != null && cause != cause.getCause();
cause = cause.getCause()) {
if (cause instanceof ArithmeticException arithmeticException) {
return arithmeticException;
}
}
return null;
}

// TODO https://github.com/opensearch-project/sql/issues/3457
// Calcite is not available for SQL query now. Maybe release in 3.1.0?
private boolean shouldUseCalcite(QueryType queryType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,72 @@

package org.opensearch.sql.expression.function.CollectionUDF;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.sql.type.SqlTypeName;

/** Core logic for `mvappend` command to collect elements from list of args */
public class MVAppendCore {

/**
* Collect non-null elements from `args`. If an item is a list, it will collect non-null elements
* of the list. See {@ref MVAppendFunctionImplTest} for detailed behavior.
* of the list. Each collected element is coerced to {@code elementType} so a heterogeneously
* boxed input (e.g. an {@code array(int_col)} operand contributing {@code Integer} cells to a
* {@code BIGINT}-typed result) does not throw {@code ClassCastException} when the array is later
* materialized by Avatica's per-type accessor. See {@ref MVAppendFunctionImplTest} for detailed
* behavior.
*/
/** Untyped overload — collects without element coercion (used by map-append and unit tests). */
public static List<Object> collectElements(Object... args) {
return collectElements((SqlTypeName) null, args);
}

public static List<Object> collectElements(SqlTypeName elementType, Object... args) {
List<Object> elements = new ArrayList<>();

for (Object arg : args) {
if (arg == null) {
continue;
} else if (arg instanceof List) {
addListElements((List<?>) arg, elements);
addListElements((List<?>) arg, elements, elementType);
} else {
elements.add(arg);
elements.add(coerce(arg, elementType));
}
}

return elements.isEmpty() ? null : elements;
}

private static void addListElements(List<?> list, List<Object> elements) {
private static void addListElements(
List<?> list, List<Object> elements, SqlTypeName elementType) {
for (Object item : list) {
if (item != null) {
elements.add(item);
elements.add(coerce(item, elementType));
}
}
}

/**
* Align a boxed numeric element to the array's target element type. Only numeric widenings that
* arise from operand widening (e.g. INTEGER cells into a BIGINT array) are handled; non-numeric
* or null-typed targets pass the value through unchanged so mixed / ANY-typed arrays keep their
* existing {@code Object[]} runtime semantics.
*/
private static Object coerce(Object value, SqlTypeName elementType) {
if (elementType == null || !(value instanceof Number)) {
return value;
}
Number num = (Number) value;
return switch (elementType) {
case TINYINT -> num.byteValue();
case SMALLINT -> num.shortValue();
case INTEGER -> num.intValue();
case BIGINT -> num.longValue();
case FLOAT, REAL -> num.floatValue();
case DOUBLE -> num.doubleValue();
case DECIMAL -> num instanceof BigDecimal ? num : BigDecimal.valueOf(num.doubleValue());
default -> value;
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,27 @@ public Expression implement(
coerced.add(EnumUtils.convert(op, elementClass));
}
}
// Pass the target element SqlTypeName so the runtime can align the elements flattened out of
// ARRAY operands. Calcite does not element-wise cast inside an array operand, so
// `mvappend(array(int_col), int_col * 2)` — where operand widening makes the result element
// type BIGINT while `array(int_col)` still yields Integer cells — would otherwise throw
// `Integer cannot be cast to Long` when the array is materialized. Scalars are already
// pre-cast above; the runtime coercion is a no-op for them.
SqlTypeName targetType = elementType == null ? null : elementType.getSqlTypeName();
return Expressions.call(
Types.lookupMethod(MVAppendFunctionImpl.class, "mvappend", Object[].class),
Types.lookupMethod(
MVAppendFunctionImpl.class, "mvappendTyped", SqlTypeName.class, Object[].class),
Expressions.constant(targetType, SqlTypeName.class),
Expressions.newArrayInit(Object.class, coerced));
}
}

/** Codegen entry point: coerces flattened elements to {@code elementType}. */
public static Object mvappendTyped(SqlTypeName elementType, Object... args) {
return MVAppendCore.collectElements(elementType, args);
}

/** Untyped entry point used by unit tests; performs no element coercion. */
public static Object mvappend(Object... args) {
return MVAppendCore.collectElements(args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.sql.expression.function.CollectionUDF;

import static org.opensearch.sql.expression.function.BuiltinFunctionName.ADDFUNCTION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_LENGTH;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.ARRAY_SLICE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IF;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_ITEM;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LESS;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.SUBTRACT;

import java.math.BigDecimal;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.opensearch.sql.expression.function.PPLFuncImpTable;

/**
Expand All @@ -37,6 +36,10 @@
* <li>Range access uses Calcite's ARRAY_SLICE operator (0-based indexing with length parameter)
* <li>Index conversion handles the difference between PPL's 0-based indexing and Calcite's
* conventions
* <li>Index arithmetic uses Calcite's raw {@code PLUS}/{@code MINUS} rather than PPL's widening
* {@code +}/{@code -} operators: array indices are int-domain and {@code ITEM}/{@code
* ARRAY_SLICE} require an INTEGER index, so the deliberate integer-overflow widening applied
* to user arithmetic must not leak into these internal, bounded computations.
* </ul>
*/
public class MVIndexFunctionImp implements PPLFuncImpTable.FunctionImp {
Expand All @@ -59,6 +62,16 @@ public RexNode resolve(RexBuilder builder, RexNode... args) {
}
}

/** Non-widening integer addition for internal, int-domain array-index math. */
private static RexNode add(RexBuilder builder, RexNode left, RexNode right) {
return builder.makeCall(SqlStdOperatorTable.PLUS, left, right);
}

/** Non-widening integer subtraction for internal, int-domain array-index math. */
private static RexNode subtract(RexBuilder builder, RexNode left, RexNode right) {
return builder.makeCall(SqlStdOperatorTable.MINUS, left, right);
}

/**
* Resolves single element access: mvindex(array, index)
*
Expand All @@ -72,11 +85,9 @@ private RexNode resolveSingleElement(
RexNode one = builder.makeExactLiteral(BigDecimal.ONE);

RexNode isNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, startIdx, zero);
RexNode sumArrayLenStart =
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, startIdx);
RexNode negativeCase =
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, sumArrayLenStart, one);
RexNode positiveCase = PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, startIdx, one);
RexNode sumArrayLenStart = add(builder, arrayLen, startIdx);
RexNode negativeCase = add(builder, sumArrayLenStart, one);
RexNode positiveCase = add(builder, startIdx, one);

RexNode normalizedStart =
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isNegative, negativeCase, positiveCase);
Expand All @@ -97,21 +108,18 @@ private RexNode resolveRange(
RexNode one = builder.makeExactLiteral(BigDecimal.ONE);

RexNode isStartNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, startIdx, zero);
RexNode startNegativeCase =
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, startIdx);
RexNode startNegativeCase = add(builder, arrayLen, startIdx);
RexNode normalizedStart =
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isStartNegative, startNegativeCase, startIdx);

RexNode isEndNegative = PPLFuncImpTable.INSTANCE.resolve(builder, LESS, endIdx, zero);
RexNode endNegativeCase =
PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, arrayLen, endIdx);
RexNode endNegativeCase = add(builder, arrayLen, endIdx);
RexNode normalizedEnd =
PPLFuncImpTable.INSTANCE.resolve(builder, IF, isEndNegative, endNegativeCase, endIdx);

// Calculate length: (normalizedEnd - normalizedStart) + 1
RexNode diff =
PPLFuncImpTable.INSTANCE.resolve(builder, SUBTRACT, normalizedEnd, normalizedStart);
RexNode length = PPLFuncImpTable.INSTANCE.resolve(builder, ADDFUNCTION, diff, one);
RexNode diff = subtract(builder, normalizedEnd, normalizedStart);
RexNode length = add(builder, diff, one);

// Call ARRAY_SLICE(array, normalizedStart, length)
return PPLFuncImpTable.INSTANCE.resolve(builder, ARRAY_SLICE, array, normalizedStart, length);
Expand Down
Loading
Loading