package org.apache.spark.sql.comet;

import java.util.HashMap;
import org.apache.comet.CometRuntimeException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.ScalarSubquery;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;

/* loaded from: input_file:org/apache/spark/sql/comet/CometScalarSubquery.class */
/**
 * Static registry of {@link ScalarSubquery} instances, keyed by plan id and subquery expression
 * id, with typed accessors that evaluate a registered subquery and unwrap its result.
 *
 * <p>Thread-safety: every access to the registry goes through a method synchronized on this
 * class. Subquery evaluation itself runs outside the lock.
 *
 * <p>NOTE(review): the public static accessors are presumably invoked from native code (JNI);
 * keep their names and {@code (long, long)} signatures stable.
 */
public class CometScalarSubquery {
    /**
     * Registered subqueries: plan id -> (subquery expression id -> subquery). Empty inner maps
     * are pruned by {@link #removeSubquery}. Guarded by {@code CometScalarSubquery.class}.
     */
    private static final HashMap<Long, HashMap<Long, ScalarSubquery>> subqueryMap = new HashMap<>();

    /**
     * Registers {@code subquery} under {@code planId}, keyed by {@code subquery.exprId().id()}.
     * Replaces any subquery already registered with the same plan id and expression id.
     *
     * @param planId id of the plan that owns the subquery
     * @param subquery subquery to register; must not be null
     */
    public static synchronized void setSubquery(long planId, ScalarSubquery subquery) {
        subqueryMap
            .computeIfAbsent(planId, ignored -> new HashMap<>())
            .put(subquery.exprId().id(), subquery);
    }

    /**
     * Unregisters the subquery with {@code subquery.exprId().id()} from {@code planId}. When the
     * plan has no subqueries left, its entry is dropped. Unknown plan ids are ignored.
     *
     * @param planId id of the plan that owns the subquery
     * @param subquery subquery whose expression id identifies the entry to remove
     */
    public static synchronized void removeSubquery(long planId, ScalarSubquery subquery) {
        HashMap<Long, ScalarSubquery> planSubqueries = subqueryMap.get(planId);
        if (planSubqueries == null) {
            return;
        }
        planSubqueries.remove(subquery.exprId().id());
        if (planSubqueries.isEmpty()) {
            subqueryMap.remove(planId);
        }
    }

    /**
     * Looks up a registered subquery while holding the class lock.
     *
     * @throws CometRuntimeException if either the plan or the subquery id is not registered
     */
    private static synchronized ScalarSubquery lookup(long planId, long id) {
        HashMap<Long, ScalarSubquery> planSubqueries = subqueryMap.get(planId);
        ScalarSubquery subquery = planSubqueries == null ? null : planSubqueries.get(id);
        if (subquery == null) {
            // Previously a missing id under a known plan surfaced as a bare NullPointerException.
            throw new CometRuntimeException("Subquery " + id + " not found for plan " + planId + ".");
        }
        return subquery;
    }

    /**
     * Evaluates the registered subquery and returns its raw result (may be null).
     *
     * <p>The lookup is synchronized, but evaluation happens after the lock is released so that
     * subquery code never runs while holding the registry lock.
     *
     * @throws CometRuntimeException if the subquery is not registered
     */
    private static Object getSubquery(long planId, long id) {
        return lookup(planId, id).eval((InternalRow) null);
    }

    /** Returns whether the subquery's result is null. */
    public static boolean isNull(long planId, long id) {
        return getSubquery(planId, id) == null;
    }

    /**
     * Returns the result as a {@code boolean}.
     *
     * <p>Like the other primitive getters, this throws {@link NullPointerException} when the
     * result is null (check {@link #isNull} first) and {@link ClassCastException} when the result
     * is not a {@link Boolean}.
     */
    public static boolean getBoolean(long planId, long id) {
        return (Boolean) getSubquery(planId, id);
    }

    /** Returns the result as a {@code byte}; the result must be a non-null {@link Byte}. */
    public static byte getByte(long planId, long id) {
        return (Byte) getSubquery(planId, id);
    }

    /** Returns the result as a {@code short}; the result must be a non-null {@link Short}. */
    public static short getShort(long planId, long id) {
        return (Short) getSubquery(planId, id);
    }

    /** Returns the result as an {@code int}; the result must be a non-null {@link Integer}. */
    public static int getInt(long planId, long id) {
        return (Integer) getSubquery(planId, id);
    }

    /** Returns the result as a {@code long}; the result must be a non-null {@link Long}. */
    public static long getLong(long planId, long id) {
        return (Long) getSubquery(planId, id);
    }

    /** Returns the result as a {@code float}; the result must be a non-null {@link Float}. */
    public static float getFloat(long planId, long id) {
        return (Float) getSubquery(planId, id);
    }

    /** Returns the result as a {@code double}; the result must be a non-null {@link Double}. */
    public static double getDouble(long planId, long id) {
        return (Double) getSubquery(planId, id);
    }

    /**
     * Returns the unscaled value of a {@link Decimal} result as a big-endian two's-complement
     * byte array (see {@link java.math.BigInteger#toByteArray()}). The scale is not included.
     */
    public static byte[] getDecimal(long planId, long id) {
        Decimal decimal = (Decimal) getSubquery(planId, id);
        return decimal.toJavaBigDecimal().unscaledValue().toByteArray();
    }

    /** Returns a {@link UTF8String} result converted to a Java {@link String}. */
    public static String getString(long planId, long id) {
        UTF8String value = (UTF8String) getSubquery(planId, id);
        return value.toString();
    }

    /** Returns a binary result as-is; the array is not copied. */
    public static byte[] getBinary(long planId, long id) {
        return (byte[]) getSubquery(planId, id);
    }
}
