Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
import org.opensearch.sql.ast.tree.StreamWindow;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Timewrap;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Union;
Expand Down Expand Up @@ -301,6 +302,10 @@ public T visitChart(Chart node, C context) {
return visitChildren(node, context);
}

public T visitTimewrap(Timewrap node, C context) {
return visitChildren(node, context);
}

public T visitRegex(Regex node, C context) {
return visitChildren(node, context);
}
Expand Down
48 changes: 48 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Timewrap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.SpanUnit;

/** AST node representing the timewrap command. */
@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class Timewrap extends UnresolvedPlan {
private final SpanUnit unit;
private final int value;
private final String align; // "end" or "now"
private final String series; // "relative", "short", or "exact"
private final String timeFormat; // format string for series=exact, nullable
private final Literal spanLiteral; // original span literal for display

private UnresolvedPlan child;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
return this.child == null ? ImmutableList.of() : ImmutableList.of(this.child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitTimewrap(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ public class CalcitePlanContext {
/** This thread local variable is only used to skip script encoding in script pushdown. */
public static final ThreadLocal<Boolean> skipEncoding = ThreadLocal.withInitial(() -> false);

/** When true, the execution engine strips all-null columns from the result (used by timewrap). */
public static final ThreadLocal<Boolean> stripNullColumns = ThreadLocal.withInitial(() -> false);

/**
* Timewrap span unit name for column renaming in the execution engine. When set, the execution
* engine uses __base_offset__ to compute absolute period names (e.g., "501days_before").
*/
public static final ThreadLocal<String> timewrapUnitName = new ThreadLocal<>();

/** Timewrap series mode: "relative", "short", or "exact". */
public static final ThreadLocal<String> timewrapSeries = new ThreadLocal<>();

/** Thread-local switch that tells whether the current query prefers legacy behavior. */
private static final ThreadLocal<Boolean> legacyPreferredFlag =
ThreadLocal.withInitial(() -> true);
Expand Down Expand Up @@ -169,6 +181,7 @@ public static void run(Runnable action, Settings settings) {
action.run();
} finally {
legacyPreferredFlag.remove();
clearTimewrapSignals();
}
}

Expand All @@ -179,6 +192,17 @@ public static boolean isLegacyPreferred() {
return legacyPreferredFlag.get();
}

/**
* Resets the timewrap thread-locals set by {@code CalciteRelNodeVisitor.visitTimewrap}. Called
* from the query lifecycle's {@code finally} on every path (execute, explain, and exceptions) so
* the signals never leak onto the next query that reuses this pooled worker thread.
*/
public static void clearTimewrapSignals() {
stripNullColumns.set(false);
timewrapUnitName.set(null);
timewrapSeries.set(null);
}

public void putRexLambdaRefMap(Map<String, RexLambdaRef> candidateMap) {
this.rexLambdaRefMap.putAll(candidateMap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
Expand Down Expand Up @@ -160,6 +161,7 @@
import org.opensearch.sql.ast.tree.StreamWindow;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Timewrap;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
import org.opensearch.sql.ast.tree.Union;
Expand All @@ -176,6 +178,7 @@
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.calcite.utils.PPLHintUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.calcite.utils.TimewrapUtils;
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
import org.opensearch.sql.calcite.utils.WildcardUtils;
import org.opensearch.sql.common.error.ErrorCode;
Expand All @@ -187,6 +190,7 @@
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
import org.opensearch.sql.expression.function.PPLFuncImpTable;
import org.opensearch.sql.expression.parse.RegexCommonUtils;
import org.opensearch.sql.utils.ParseUtils;
Expand Down Expand Up @@ -3823,6 +3827,148 @@ public RelNode visitChart(Chart node, CalcitePlanContext context) {
return relBuilder.peek();
}

@Override
public RelNode visitTimewrap(Timewrap node, CalcitePlanContext context) {
visitChildren(node, context);

// Signal the execution engine to strip all-null columns and rename with absolute offsets
CalcitePlanContext.stripNullColumns.set(true);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a leak-guard test: a timewrap query followed by a non-timewrap query on the same pooled thread, asserting the second result has no base_offset/period artifacts?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added TimewrapSignalsLeakTest in 41f00ee

CalcitePlanContext.timewrapUnitName.set(
TimewrapUtils.unitBaseName(node.getUnit(), node.getValue()) + "|_before");
CalcitePlanContext.timewrapSeries.set(node.getSeries());

RelBuilder b = context.relBuilder;
RexBuilder rx = context.rexBuilder;

List<String> fieldNames =
b.peek().getRowType().getFieldNames().stream().filter(f -> !isMetadataField(f)).toList();
String tsFieldName = fieldNames.get(0);
List<String> valueFieldNames = fieldNames.subList(1, fieldNames.size());

boolean variableLength = TimewrapUtils.isVariableLengthUnit(node.getUnit());
RelDataType bigintType = rx.getTypeFactory().createSqlType(SqlTypeName.BIGINT);

RexNode periodNum;
RexNode displayTimestamp;
RexNode baseOffset;

if (variableLength) {
// --- Variable-length units (month, quarter, year): EXTRACT-based calendar arithmetic ---
RexNode tsField = b.field(tsFieldName);
RexNode tsUnitNum =
TimewrapUtils.calendarUnitNumber(rx, tsField, node.getUnit(), node.getValue());

b.projectPlus(b.aggregateCall(SqlStdOperatorTable.MAX, tsField).over().as("__max_ts__"));
RexNode maxTs = b.field("__max_ts__");
RexNode maxUnitNum =
TimewrapUtils.calendarUnitNumber(rx, maxTs, node.getUnit(), node.getValue());

periodNum =
rx.makeCall(
SqlStdOperatorTable.PLUS,
rx.makeCall(SqlStdOperatorTable.MINUS, maxUnitNum, tsUnitNum),
rx.makeExactLiteral(BigDecimal.ONE, bigintType));

RexNode tsEpoch =
rx.makeCast(bigintType, rx.makeCall(PPLBuiltinOperators.UNIX_TIMESTAMP, tsField), true);
RexNode unitStartEpoch = TimewrapUtils.calendarUnitStartEpoch(rx, tsField, node.getUnit());
RexNode offsetSec = rx.makeCall(SqlStdOperatorTable.MINUS, tsEpoch, unitStartEpoch);
RexNode maxUnitStartEpoch = TimewrapUtils.calendarUnitStartEpoch(rx, maxTs, node.getUnit());
RexNode displayEpoch = rx.makeCall(SqlStdOperatorTable.PLUS, maxUnitStartEpoch, offsetSec);
displayTimestamp = rx.makeCall(PPLBuiltinOperators.FROM_UNIXTIME, displayEpoch);

long nowEpochSec = context.functionProperties.getQueryStartClock().millis() / 1000;
Long referenceEpoch = null;
if ("end".equals(node.getAlign())) {
referenceEpoch = TimewrapUtils.extractTimestampUpperBound(node);
}
if (referenceEpoch == null) {
referenceEpoch = nowEpochSec;
}
long refUnitNum =
TimewrapUtils.calendarUnitNumberFromEpoch(
referenceEpoch, node.getUnit(), node.getValue());
RexNode refUnitNumLit = rx.makeBigintLiteral(BigDecimal.valueOf(refUnitNum));
baseOffset = rx.makeCall(SqlStdOperatorTable.MINUS, refUnitNumLit, maxUnitNum);

} else {
// --- Fixed-length units (sec, min, hr, day, week): epoch-based arithmetic ---
long spanSec = TimewrapUtils.spanToSeconds(node.getUnit(), node.getValue());

RexNode tsEpochExpr =
rx.makeCast(
bigintType,
rx.makeCall(PPLBuiltinOperators.UNIX_TIMESTAMP, b.field(tsFieldName)),
true);
b.projectPlus(
b.alias(tsEpochExpr, "__ts_epoch__"),
b.aggregateCall(SqlStdOperatorTable.MAX, tsEpochExpr).over().as("__max_epoch__"));

RexNode tsEpoch = b.field("__ts_epoch__");
RexNode maxEpoch = b.field("__max_epoch__");
RexNode spanLit = rx.makeBigintLiteral(BigDecimal.valueOf(spanSec));

RexNode diff = rx.makeCall(SqlStdOperatorTable.MINUS, maxEpoch, tsEpoch);
periodNum =
rx.makeCall(
SqlStdOperatorTable.PLUS,
rx.makeCall(SqlStdOperatorTable.DIVIDE, diff, spanLit),
rx.makeExactLiteral(BigDecimal.ONE, bigintType));

RexNode offsetSec = rx.makeCall(SqlStdOperatorTable.MOD, tsEpoch, spanLit);
RexNode latestPeriodStart =
rx.makeCall(
SqlStdOperatorTable.MINUS,
maxEpoch,
rx.makeCall(SqlStdOperatorTable.MOD, maxEpoch, spanLit));
RexNode displayEpoch = rx.makeCall(SqlStdOperatorTable.PLUS, latestPeriodStart, offsetSec);
displayTimestamp = rx.makeCall(PPLBuiltinOperators.FROM_UNIXTIME, displayEpoch);

long nowEpochSec = context.functionProperties.getQueryStartClock().millis() / 1000;
Long referenceEpoch = null;
if ("end".equals(node.getAlign())) {
referenceEpoch = TimewrapUtils.extractTimestampUpperBound(node);
}
if (referenceEpoch == null) {
referenceEpoch = nowEpochSec;
}
RexNode refLit = rx.makeBigintLiteral(BigDecimal.valueOf(referenceEpoch));
// Floor-divide (ref - maxEpoch) by span: integer DIVIDE truncates toward zero, which is wrong
// when the reference is below maxEpoch (e.g. align=now over future-dated data) — it would
// shift period labels by one across the latest/before/after boundary. Cast to DOUBLE and
// FLOOR
// to get true floor division, then back to BIGINT.
RelDataType doubleType = rx.getTypeFactory().createSqlType(SqlTypeName.DOUBLE);
RexNode refDiff = rx.makeCall(SqlStdOperatorTable.MINUS, refLit, maxEpoch);
RexNode refDiffDouble = rx.makeCast(doubleType, refDiff, true);
baseOffset =
rx.makeCast(
bigintType,
rx.makeCall(
SqlStdOperatorTable.FLOOR,
rx.makeCall(SqlStdOperatorTable.DIVIDE, refDiffDouble, spanLit)),
true);
}

// Step 3: Project [display_timestamp, value_columns..., base_offset, period]
// base_offset is included in the group key so it survives the PIVOT
List<RexNode> projections = new ArrayList<>();
projections.add(b.alias(displayTimestamp, tsFieldName));
for (String vf : valueFieldNames) {
projections.add(b.field(vf));
}
projections.add(b.alias(baseOffset, "__base_offset__"));
projections.add(b.alias(periodNum, "__period__"));
b.project(projections);

// Step 4: Sort by offset, then period (execution engine will pivot)
// No Calcite PIVOT -- the execution engine pivots dynamically after reading all rows.
// Output schema: [display_timestamp, value_columns..., __base_offset__, __period__]
b.sort(b.field(tsFieldName), b.field("__period__"));

return b.peek();
}

/**
* Aggregate by column split then rank by grand total (summed value of each category). The output
* is <code>[col-split, grand-total, row-number]</code>
Expand Down
Loading
Loading