package kd.bos.flydb.core.interpreter.algox;

import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import kd.bos.algo.Input;
import kd.bos.algox.AlgoX;
import kd.bos.algox.DataSetX;
import kd.bos.algox.Grouper;
import kd.bos.algox.JobSession;
import kd.bos.algox.JoinDataSetX;
import kd.bos.dataentity.metadata.IDataEntityType;
import kd.bos.flydb.common.AlgoXOption;
import kd.bos.flydb.common.ServerOption;
import kd.bos.flydb.common.exception.ErrorCode;
import kd.bos.flydb.common.exception.Exceptions;
import kd.bos.flydb.core.Context;
import kd.bos.flydb.core.Contexts;
import kd.bos.flydb.core.Core;
import kd.bos.flydb.core.interpreter.BindNodeCompiler;
import kd.bos.flydb.core.interpreter.Executor;
import kd.bos.flydb.core.interpreter.ScalarEvaluationCompiler;
import kd.bos.flydb.core.interpreter.algox.AggregateFunction;
import kd.bos.flydb.core.interpreter.bind.BindableAggregate;
import kd.bos.flydb.core.interpreter.bind.BindableFilter;
import kd.bos.flydb.core.interpreter.bind.BindableJoin;
import kd.bos.flydb.core.interpreter.bind.BindableNode;
import kd.bos.flydb.core.interpreter.bind.BindableProject;
import kd.bos.flydb.core.interpreter.bind.BindableSort;
import kd.bos.flydb.core.interpreter.bind.BindableTableScan;
import kd.bos.flydb.core.interpreter.scalar.ScalarEvaluation;
import kd.bos.flydb.core.rel.Aggregate;
import kd.bos.flydb.core.rel.RelNode;
import kd.bos.flydb.core.rel.Sort;
import kd.bos.flydb.core.rex.RexCall;
import kd.bos.flydb.core.rex.RexInputRef;
import kd.bos.flydb.core.rex.RexLiteral;
import kd.bos.flydb.core.rex.RexNode;
import kd.bos.flydb.core.schema.FormAttribute;
import kd.bos.flydb.core.schema.Scanner;
import kd.bos.flydb.core.schema.cosmic.CosmicEntityTable;
import kd.bos.flydb.core.schema.cosmic.CosmicFormAttribute;
import kd.bos.flydb.core.schema.cosmic.IDataEntityTypeProvider;
import kd.bos.flydb.core.sql.tree.SqlJoinType;
import kd.bos.flydb.core.sql.tree.SqlKind;
import kd.bos.flydb.core.sql.type.DataType;
import kd.bos.flydb.core.sql.type.DataTypeField;
import kd.bos.trace.TraceSpan;
import kd.bos.trace.Tracer;
import kd.bos.xdb.XDBConfig;
import kd.bos.xdb.sharding.config.MainTableConfig;

/* loaded from: input_file:kd/bos/flydb/core/interpreter/algox/AlgoXBindNodeCompiler.class */
public class AlgoXBindNodeCompiler implements BindNodeCompiler {
    /** DataSetX produced for every bindable node converted so far, keyed by node identity. */
    private final IdentityHashMap<BindableNode, DataSetX> sourceMap = new IdentityHashMap<>();
    private JobSession jobSession;
    private ScalarEvaluationCompiler scalarEvaluationCompiler;
    /** DataSetX of the most recently converted node; after {@link #convert} it is the plan root. */
    private DataSetX root;
    private Context context;

    /** Creates the AlgoX job session from the job name and region configured in {@link #context}. */
    private void initJobSession() {
        JobSession session = AlgoX.createSession(this.context.getConfig(AlgoXOption.JobName.key()));
        session.getContext().setRegion(this.context.getConfig(ServerOption.AlgoXRegion.key()));
        this.jobSession = session;
    }

    /**
     * Translates a bindable plan into an AlgoX job and returns an executor that runs it.
     *
     * <p>The plan is converted bottom-up into DataSetX operators, the root is wired to a
     * {@link DataSetOutput}, and that output's cursor is registered with
     * {@link DataSetOutputManager} using the configured maximum lifetime. The job itself is only
     * committed when the returned executor is invoked.
     *
     * @param bindableNode root of the bindable plan
     * @return executor that commits the job (bounded by the configured query timeout) and
     *     returns the output row meta plus cursor id
     */
    @Override // kd.bos.flydb.core.interpreter.BindNodeCompiler
    public Executor compile(BindableNode bindableNode) {
        if (this.context == null) {
            this.context = Contexts.get();
        }
        if (this.jobSession == null) {
            initJobSession();
        }
        if (this.scalarEvaluationCompiler == null) {
            this.scalarEvaluationCompiler = new ScalarEvaluationCompiler(this.context);
        }
        convert(bindableNode);
        Objects.requireNonNull(this.root, "no DataSetX was produced for the plan root");

        DataSetOutput dataSetOutput = new DataSetOutput(CompilerHelp.convertRowType(bindableNode.getRowType()));
        // Absolute-time arithmetic: unlike LocalDateTime.plusHours(..).atZone(..) this is not
        // shifted when a daylight-saving transition falls inside the lifetime window.
        long lifeTimeHours = getIntFromConfig(ServerOption.AlgoXOutputMaxLifeTime);
        long expireAtMillis = System.currentTimeMillis() + TimeUnit.HOURS.toMillis(lifeTimeHours);
        DataSetOutputManager.register(dataSetOutput.getCursorId(), expireAtMillis);
        this.root.output(dataSetOutput);

        return () -> {
            // try-with-resources closes the span on every path and attaches a failure from
            // close() as a suppressed exception instead of masking the commit failure.
            try (TraceSpan ignored = Tracer.create(Core.TRACE_TYPE, "executeQuery")) {
                this.jobSession.commit(getIntFromConfig(ServerOption.QueryTimeout), TimeUnit.SECONDS);
                return new Executor.QueryResult(dataSetOutput.getRowMeta(), dataSetOutput.getCursorId());
            }
        };
    }

    /**
     * Converts {@code bindableNode} and, first, all of its inputs (post-order), recording each
     * result in {@link #sourceMap} and leaving the last converted DataSetX in {@link #root}.
     * Node types not handled below are left unconverted.
     */
    private void convert(BindableNode bindableNode) {
        for (RelNode input : bindableNode.getInputList()) {
            convert((BindableNode) input);
        }
        if (bindableNode instanceof BindableTableScan) {
            this.root = convertTableScan((BindableTableScan) bindableNode.cast(BindableTableScan.class));
        } else if (bindableNode instanceof BindableSort) {
            this.root = convertSort((BindableSort) bindableNode.cast(BindableSort.class));
        } else if (bindableNode instanceof BindableProject) {
            this.root = convertProject((BindableProject) bindableNode.cast(BindableProject.class));
        } else if (bindableNode instanceof BindableJoin) {
            this.root = convertJoin((BindableJoin) bindableNode.cast(BindableJoin.class));
        } else if (bindableNode instanceof BindableFilter) {
            this.root = convertFilter((BindableFilter) bindableNode.cast(BindableFilter.class));
        } else if (bindableNode instanceof BindableAggregate) {
            this.root = convertAggregate1((BindableAggregate) bindableNode.cast(BindableAggregate.class));
        }
    }

    /**
     * Builds the source DataSetX for a table scan. A multi-scanner (sharded) input is used when
     * {@link #createShardedInput} applies; otherwise a single scanner is used.
     */
    private DataSetX convertTableScan(BindableTableScan bindableTableScan) {
        DataSetX source = createShardedInput(bindableTableScan);
        if (source == null) {
            source = this.jobSession.fromInput(new TableScanInput(
                    bindableTableScan.table.createScanner(bindableTableScan.index, bindableTableScan.filter)));
        }
        this.sourceMap.put(bindableTableScan, source);
        return source;
    }

    /**
     * Returns a DataSetX reading from one {@link TableScanInput} per scanner, or {@code null} when
     * sharded input does not apply. It applies only when the {@code EnableShardingTableInput}
     * option is true, XDB is enabled, the table is a {@link CosmicEntityTable}, the top-most
     * parent entity type has an enabled sharding config, and the table's form attribute is a
     * {@link CosmicFormAttribute}.
     *
     * @throws RuntimeException (from {@code Exceptions.of}) if the table yields no scanners
     */
    private DataSetX createShardedInput(BindableTableScan scan) {
        boolean candidate = Boolean.parseBoolean(getStringFromConfig(ServerOption.EnableShardingTableInput))
                && XDBConfig.isXDBEnabled()
                && scan.table instanceof CosmicEntityTable;
        if (!candidate) {
            return null;
        }
        IDataEntityType entityType = IDataEntityTypeProvider.get().load(scan.table.getName());
        if (entityType == null) {
            return null;
        }
        // Sharding configuration is looked up on the top-most parent entity type.
        while (entityType.getParent() != null) {
            entityType = entityType.getParent();
        }
        MainTableConfig shardingConfig = XDBConfig.getShardingConfigProvider().getConfigByEntity(entityType.getName());
        if (shardingConfig == null || !shardingConfig.isEnabled()) {
            return null;
        }
        FormAttribute formAttribute = scan.table.getFormAttribute();
        if (!(formAttribute instanceof CosmicFormAttribute)) {
            return null;
        }
        Scanner[] scanners = scan.table.createScanners(scan.index, scan.filter,
                ((CosmicFormAttribute) formAttribute).getEntityType());
        if (scanners.length == 0) {
            throw Exceptions.of(ErrorCode.Unexpected1, new Object[]{"scanners is empty"});
        }
        Input[] inputs = Arrays.stream(scanners)
                .map(scanner -> new TableScanInput(scanner))
                .toArray(TableScanInput[]::new);
        return this.jobSession.fromInput(inputs);
    }

    /**
     * Reads a config value from the compiler's context (falling back to {@link Contexts#get()}
     * before {@link #compile} has captured one), or the option's default when it is unset.
     */
    private String getStringFromConfig(ServerOption serverOption) {
        Context ctx = this.context != null ? this.context : Contexts.get();
        String config = ctx.getConfig(serverOption.key());
        return config != null ? config : serverOption.defaultValue();
    }

    /**
     * Integer variant of {@link #getStringFromConfig}.
     *
     * @throws NumberFormatException if neither the configured nor the default value is an integer
     */
    private int getIntFromConfig(ServerOption serverOption) {
        return Integer.parseInt(getStringFromConfig(serverOption));
    }

    /**
     * Converts ORDER BY / LIMIT / OFFSET. Sort keys must be plain column references. When a limit
     * or offset is present, the (possibly ordered) data is reduced by a
     * {@link LimitOffsetFunction} on a single parallel instance.
     *
     * @throws RuntimeException (from {@code Exceptions.of}) with
     *     {@code OrderWithComplexExpression} if a sort key is not a column reference
     */
    private DataSetX convertSort(BindableSort bindableSort) {
        DataSetX input = this.sourceMap.get((BindableNode) bindableSort.getInput(0).cast(BindableNode.class));
        DataSetX result = input;
        if (!bindableSort.sortItemList.isEmpty()) {
            List<String> orderItems = new ArrayList<>(bindableSort.sortItemList.size());
            for (Sort.SortItem sortItem : bindableSort.sortItemList) {
                // Reject before casting; casting first would fail before this error could be raised.
                if (!(sortItem.expression instanceof RexInputRef)) {
                    throw Exceptions.of(ErrorCode.OrderWithComplexExpression, new Object[0]);
                }
                int fieldIndex = ((RexInputRef) sortItem.expression.cast(RexInputRef.class)).getIndex();
                // Locale.ROOT: the keyword must not depend on the JVM locale (e.g. Turkish dotless i).
                String direction = sortItem.ordering.name().toLowerCase(Locale.ROOT);
                orderItems.add(String.format("%s %s", input.getRowMeta().getFieldName(fieldIndex), direction));
            }
            result = input.orderBy(new String[]{String.join(",", orderItems)});
        }
        if (bindableSort.limit != null || bindableSort.offset != null) {
            Integer offset = bindableSort.offset != null
                    ? (Integer) ((RexLiteral) bindableSort.offset.cast(RexLiteral.class)).getValue()
                    : null;
            Integer limit = bindableSort.limit != null
                    ? (Integer) ((RexLiteral) bindableSort.limit.cast(RexLiteral.class)).getValue()
                    : null;
            result = result.reduceGroup(new LimitOffsetFunction(result.getRowMeta(), offset, limit));
            result.setSingleParallel(true);
        }
        this.sourceMap.put(bindableSort, result);
        return result;
    }

    /** Compiles each expression to a {@link ScalarEvaluation} bound to {@link #context}. */
    private List<ScalarEvaluation> compileExpressions(List<? extends RexNode> expressions) {
        List<ScalarEvaluation> evaluations = new ArrayList<>(expressions.size());
        for (RexNode expression : expressions) {
            ScalarEvaluation evaluation = this.scalarEvaluationCompiler.compile(expression);
            evaluation.setContext(this.context);
            evaluations.add(evaluation);
        }
        return evaluations;
    }

    /**
     * Converts a projection into a map over its input. A projection without input is evaluated
     * eagerly into a one-row source.
     */
    private DataSetX convertProject(BindableProject bindableProject) {
        if (bindableProject.getInput(0) == null) {
            return convertOneRowScalarExpressionValue(bindableProject);
        }
        DataSetX input = this.sourceMap.get((BindableNode) bindableProject.getInput(0).cast(BindableNode.class));
        List<ScalarEvaluation> evaluations = compileExpressions(bindableProject.exprList);
        DataSetX mapped = input.map(new ProjectFunction(bindableProject.getRowType(), evaluations));
        this.sourceMap.put(bindableProject, mapped);
        return mapped;
    }

    /**
     * Evaluates every projection expression now, with an empty input row, and exposes the values
     * as a single-row {@link OneRowInput} source.
     */
    private DataSetX convertOneRowScalarExpressionValue(BindableProject bindableProject) {
        List<ScalarEvaluation> evaluations = compileExpressions(bindableProject.exprList);
        Object[] values = new Object[evaluations.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = evaluations.get(i).eval(new Object[0]);
        }
        DataSetX source = this.jobSession.fromInput(
                new OneRowInput(CompilerHelp.convertRowType(bindableProject.getRowType()), values));
        this.sourceMap.put(bindableProject, source);
        return source;
    }

    /**
     * Flattens a join condition into its AND-ed equality terms, appending each term to
     * {@code list}. Every term must be an EQUALS call between two column references.
     *
     * <p>NOTE(review): only operands 0 and 1 of an AND call are visited; if AND calls can carry
     * more than two operands, the extra terms are ignored — confirm against RexCall.
     *
     * @throws RuntimeException (from {@code Exceptions.of}) for a null condition, an OR, or any
     *     term that is not an equality between two column references
     */
    private void unzipJoinCondition(RexNode rexNode, List<RexNode> list) {
        if (rexNode == null) {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition1, new Object[0]);
        }
        if (rexNode.getKind() == SqlKind.OR) {
            throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
        }
        if (rexNode.getKind() == SqlKind.AND) {
            RexCall andCall = (RexCall) rexNode.cast(RexCall.class);
            unzipJoinCondition(andCall.getOperand(0), list);
            unzipJoinCondition(andCall.getOperand(1), list);
            return;
        }
        if (!(rexNode instanceof RexCall)) {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition, new Object[]{rexNode.toString()});
        }
        RexCall call = (RexCall) rexNode.cast(RexCall.class);
        if (call.getOperator().getKind() != SqlKind.EQUALS) {
            throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
        }
        if (!(call.getOperand(0) instanceof RexInputRef) || !(call.getOperand(1) instanceof RexInputRef)) {
            throw Exceptions.of(ErrorCode.OnlySupportedAndOperatorEquiJoin, new Object[0]);
        }
        list.add(call);
    }

    /**
     * Converts an INNER/LEFT/RIGHT/FULL equi-join. Right-side fields whose names clash with a
     * left-side field are renamed first. Each equality term must reference exactly one column
     * from each side, in either order.
     *
     * @throws RuntimeException (from {@code Exceptions.of}) for CROSS joins, conditions without
     *     equality terms, or terms that compare two columns of the same side
     */
    private DataSetX convertJoin(BindableJoin bindableJoin) {
        BindableNode leftNode = (BindableNode) bindableJoin.getInput(0).cast(BindableNode.class);
        BindableNode rightNode = (BindableNode) bindableJoin.getInput(1).cast(BindableNode.class);
        DataSetX left = this.sourceMap.get(leftNode);
        DataSetX right = this.sourceMap.get(rightNode);
        int leftFieldCount = leftNode.getRowType().getFieldCount();

        Set<String> leftNames = new HashSet<>(leftFieldCount + rightNode.getRowType().getFieldCount());
        leftNames.addAll(Arrays.asList(left.getRowMeta().getFieldNames()));
        String[] rightNames = new String[right.getRowMeta().getFieldCount()];
        boolean renamed = false;
        for (int i = 0; i < rightNames.length; i++) {
            String fieldName = right.getRowMeta().getFieldName(i);
            if (leftNames.contains(fieldName)) {
                renamed = true;
                // String concatenation: leftFieldCount and i are appended as digits, not summed.
                fieldName = fieldName + '$' + leftFieldCount + i;
            }
            rightNames[i] = fieldName;
        }
        if (renamed) {
            right = right.map(new RenameFunction(right.getRowMeta(), rightNames));
        }

        List<RexNode> conditions = new ArrayList<>();
        unzipJoinCondition(bindableJoin.condition, conditions);
        if (conditions.isEmpty()) {
            throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition1, new Object[0]);
        }
        if (bindableJoin.joinType == SqlJoinType.CROSS) {
            throw Exceptions.of(ErrorCode.UnsupportedFeature, new Object[]{"CROSS JOIN"});
        }

        JoinDataSetX join;
        switch (bindableJoin.joinType) {
            case INNER:
                join = left.join(right);
                break;
            case LEFT:
                join = left.leftJoin(right);
                break;
            case RIGHT:
                join = left.rightJoin(right);
                break;
            case FULL:
                join = left.fullJoin(right);
                break;
            default:
                throw Exceptions.of(ErrorCode.UnsupportedKeyword, new Object[]{bindableJoin.joinType});
        }

        for (RexNode condition : conditions) {
            RexCall call = (RexCall) condition.cast(RexCall.class);
            int leftIndex = ((RexInputRef) call.getOperand(0).cast(RexInputRef.class)).getIndex();
            int rightIndex = ((RexInputRef) call.getOperand(1).cast(RexInputRef.class)).getIndex();
            boolean bothLeft = leftIndex < leftFieldCount && rightIndex < leftFieldCount;
            boolean bothRight = leftIndex >= leftFieldCount && rightIndex >= leftFieldCount;
            if (bothLeft || bothRight) {
                throw Exceptions.of(ErrorCode.UnsupportedSqlJoinCondition, new Object[]{bindableJoin.condition.toString()});
            }
            // Normalize "right.col = left.col" to left-first. A proper swap is required here;
            // plain sequential assignment would leave both indexes equal.
            if (leftIndex > rightIndex) {
                int tmp = leftIndex;
                leftIndex = rightIndex;
                rightIndex = tmp;
            }
            join = join.on(CompilerHelp.getFieldName(left, leftIndex),
                    CompilerHelp.getFieldName(right, rightIndex - left.getRowMeta().getFieldCount()));
        }
        this.sourceMap.put(bindableJoin, join);
        return join;
    }

    /** Converts a WHERE/HAVING predicate into a {@link FilterFunction} over its input. */
    private DataSetX convertFilter(BindableFilter bindableFilter) {
        DataSetX input = this.sourceMap.get((BindableNode) bindableFilter.getInput(0).cast(BindableNode.class));
        ScalarEvaluation predicate = this.scalarEvaluationCompiler.compile(bindableFilter.condition);
        predicate.setContext(this.context);
        DataSetX filtered = input.filter(new FilterFunction(predicate));
        this.sourceMap.put(bindableFilter, filtered);
        return filtered;
    }

    /**
     * Converts an aggregate. Without group keys the whole input is reduced by a single
     * {@link AggregateFunction}; otherwise the input is grouped by the key columns first.
     */
    private DataSetX convertAggregate1(BindableAggregate bindableAggregate) {
        DataSetX input = this.sourceMap.get((BindableNode) bindableAggregate.getInput(0).cast(BindableNode.class));
        DataType rowType = bindableAggregate.getRowType();

        // groupList holds ordinals into the input row (they are handed to AggregateFunction
        // alongside the input-based AggCall indexes), and groupBy runs on the input DataSetX,
        // so the names must come from the input row meta rather than the output row type.
        String[] groupFieldNames = new String[bindableAggregate.groupList.size()];
        for (int i = 0; i < groupFieldNames.length; i++) {
            groupFieldNames[i] = input.getRowMeta().getFieldName(bindableAggregate.groupList.get(i).intValue());
        }

        List<AggregateFunction.AggCall> aggCalls = new ArrayList<>(bindableAggregate.aggCallList.size());
        for (Aggregate.AggCall aggCall : bindableAggregate.aggCallList) {
            aggCalls.add(new AggregateFunction.AggCall(aggCall.operator.getKind(), aggCall.index,
                    aggCall.distinct, aggCall.ignoreNull, aggCall.type));
        }
        AggregateFunction aggregateFunction =
                new AggregateFunction(CompilerHelp.convertRowType(rowType), bindableAggregate.groupList, aggCalls);

        DataSetX reduced;
        if (bindableAggregate.groupList.isEmpty()) {
            reduced = input.reduceGroup(aggregateFunction);
        } else {
            Grouper grouper = input.groupBy(groupFieldNames);
            reduced = grouper.reduceGroup(aggregateFunction);
        }
        this.sourceMap.put(bindableAggregate, reduced);
        return reduced;
    }
}
