package dev.langchain4j.agent.tool;

import dev.langchain4j.internal.Json;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/agent/tool/DefaultToolExecutor.class */
/**
 * {@link ToolExecutor} that invokes a tool {@link Method} reflectively on a target object.
 *
 * <p>Request arguments are matched to method parameters by parameter name and coerced to the
 * declared parameter types. Parameters annotated with {@link ToolMemoryId} receive the memory id
 * passed to {@link #execute(ToolExecutionRequest, Object)}.
 */
public class DefaultToolExecutor implements ToolExecutor {
    private static final Logger log = LoggerFactory.getLogger(DefaultToolExecutor.class);

    private final Object object;
    private final Method method;

    /**
     * Creates an executor that invokes {@code method} on {@code object}.
     *
     * @param object the instance the tool method is invoked on; must not be null
     * @param method the tool method; must not be null
     * @throws NullPointerException if either argument is null
     */
    public DefaultToolExecutor(Object object, Method method) {
        this.object = Objects.requireNonNull(object, "object");
        this.method = Objects.requireNonNull(method, "method");
    }

    /**
     * Executes the tool method for the given request.
     *
     * <p>If the method is not accessible, access is forced via {@code setAccessible(true)} and the
     * invocation is retried once. If the tool method itself throws, the exception is logged and its
     * message is returned as the result.
     *
     * @param toolExecutionRequest request whose arguments are bound to the method parameters by name
     * @param memoryId value passed to every parameter annotated with {@link ToolMemoryId}
     * @return {@code "Success"} for {@code void} methods, the returned value for {@code String}
     *     methods, {@code Json.toJson(result)} otherwise; or the message of the thrown exception
     * @throws IllegalArgumentException if an argument cannot be coerced to its parameter type
     * @throws RuntimeException wrapping {@link IllegalAccessException} if the method is still
     *     inaccessible after {@code setAccessible(true)}
     */
    @Override
    public String execute(ToolExecutionRequest toolExecutionRequest, Object memoryId) {
        log.debug("About to execute {} for memoryId {}", toolExecutionRequest, memoryId);
        Object[] arguments = prepareArguments(
                this.method, ToolExecutionRequestUtil.argumentsAsMap(toolExecutionRequest.arguments()), memoryId);
        try {
            String result = invokeWithAccessRetry(arguments);
            log.debug("Tool execution result: {}", result);
            return result;
        } catch (InvocationTargetException e) {
            // The tool itself threw: report its message back instead of propagating.
            Throwable cause = e.getCause();
            log.error("Error while executing tool", cause);
            return cause.getMessage();
        }
    }

    /**
     * Invokes the tool, forcing accessibility and retrying once if the first attempt is denied.
     *
     * @throws InvocationTargetException if the tool method throws
     * @throws RuntimeException wrapping {@link IllegalAccessException} if the retry is also denied
     */
    private String invokeWithAccessRetry(Object[] arguments) throws InvocationTargetException {
        try {
            return execute(arguments);
        } catch (IllegalAccessException firstAttempt) {
            this.method.setAccessible(true);
            try {
                return execute(arguments);
            } catch (IllegalAccessException secondAttempt) {
                throw new RuntimeException(secondAttempt);
            }
        }
    }

    /**
     * Invokes the tool method and renders its return value as text: {@code "Success"} for
     * {@code void} methods, the value itself for {@code String} methods, {@code Json.toJson}
     * otherwise.
     */
    private String execute(Object[] arguments) throws IllegalAccessException, InvocationTargetException {
        Object result = this.method.invoke(this.object, arguments);
        Class<?> returnType = this.method.getReturnType();
        if (returnType == void.class) {
            return "Success";
        }
        if (returnType == String.class) {
            return (String) result;
        }
        return Json.toJson(result);
    }

    /**
     * Builds the positional argument array for {@code method}.
     *
     * <p>Parameters annotated with {@link ToolMemoryId} receive {@code memoryId}. Other parameters
     * are looked up in {@code argumentsMap} by parameter name and coerced with
     * {@link #coerceArgument}. Parameters with no matching key are left {@code null}.
     *
     * @param method the tool method whose parameters are filled
     * @param argumentsMap argument values keyed by parameter name
     * @param memoryId value for {@link ToolMemoryId}-annotated parameters
     * @return one entry per method parameter, in declaration order
     * @throws IllegalArgumentException if an argument cannot be coerced to its parameter type
     */
    static Object[] prepareArguments(Method method, Map<String, Object> argumentsMap, Object memoryId) {
        Parameter[] parameters = method.getParameters();
        Object[] arguments = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            Parameter parameter = parameters[i];
            if (parameter.isAnnotationPresent(ToolMemoryId.class)) {
                arguments[i] = memoryId;
                continue;
            }
            String parameterName = parameter.getName();
            if (argumentsMap.containsKey(parameterName)) {
                arguments[i] = coerceArgument(argumentsMap.get(parameterName), parameterName, parameter.getType());
            }
        }
        return arguments;
    }

    /**
     * Converts a raw argument value to {@code parameterType}.
     *
     * <p>Supported targets are {@code String} (via {@code toString()}), enums (by constant name),
     * {@code boolean} (only from a {@code Boolean}), {@code double}/{@code float}/{@code BigDecimal}
     * (from any {@code Number}), and {@code int}/{@code long}/{@code short}/{@code byte}/{@code BigInteger}
     * (from a {@code Number} with no fractional part, range-checked). Any other type is returned
     * unchanged. A {@code null} argument yields {@code null} for reference types.
     *
     * @param argument the raw value, may be null
     * @param parameterName parameter name, used only in error messages
     * @param parameterType the declared parameter type
     * @return the coerced value
     * @throws IllegalArgumentException if the value has the wrong kind, is fractional for an
     *     integral type, is out of range, is not a valid enum constant, or is null for a primitive
     */
    static Object coerceArgument(Object argument, String parameterName, Class<?> parameterType) {
        if (argument == null) {
            // Treat an explicit null like an absent argument; primitives cannot hold null.
            if (parameterType.isPrimitive()) {
                throw new IllegalArgumentException(String.format(
                        "Argument \"%s\" is null but %s is a primitive type", parameterName, parameterType.getName()));
            }
            return null;
        }
        if (parameterType == String.class) {
            return argument.toString();
        }
        if (parameterType.isEnum()) {
            try {
                // Safe: isEnum() guarantees parameterType is an enum class.
                @SuppressWarnings({"unchecked", "rawtypes"})
                Class<Enum> enumType = (Class<Enum>) parameterType;
                return Enum.valueOf(enumType, Objects.requireNonNull(argument.toString()));
            } catch (IllegalArgumentException | NullPointerException e) {
                throw new IllegalArgumentException(String.format(
                        "Argument \"%s\" is not a valid enum value for %s: <%s>",
                        parameterName, parameterType.getName(), argument), e);
            }
        }
        if (parameterType == Boolean.class || parameterType == boolean.class) {
            if (argument instanceof Boolean) {
                return argument;
            }
            throw new IllegalArgumentException(String.format(
                    "Argument \"%s\" is not convertable to %s, got %s: <%s>",
                    parameterName, parameterType.getName(), argument.getClass().getName(), argument));
        }
        if (parameterType == Double.class || parameterType == double.class) {
            return getDoubleValue(argument, parameterName, parameterType);
        }
        if (parameterType == Float.class || parameterType == float.class) {
            double value = getDoubleValue(argument, parameterName, parameterType);
            // Lower bound is -Float.MAX_VALUE; Float.MIN_VALUE is the smallest *positive* float.
            checkBounds(value, parameterName, parameterType, -Float.MAX_VALUE, Float.MAX_VALUE);
            return (float) value;
        }
        if (parameterType == BigDecimal.class) {
            return BigDecimal.valueOf(getDoubleValue(argument, parameterName, parameterType));
        }
        if (parameterType == Integer.class || parameterType == int.class) {
            return (int) getBoundedLongValue(
                    argument, parameterName, parameterType, Integer.MIN_VALUE, Integer.MAX_VALUE);
        }
        if (parameterType == Long.class || parameterType == long.class) {
            return getBoundedLongValue(argument, parameterName, parameterType, Long.MIN_VALUE, Long.MAX_VALUE);
        }
        if (parameterType == Short.class || parameterType == short.class) {
            return (short) getBoundedLongValue(
                    argument, parameterName, parameterType, Short.MIN_VALUE, Short.MAX_VALUE);
        }
        if (parameterType == Byte.class || parameterType == byte.class) {
            return (byte) getBoundedLongValue(
                    argument, parameterName, parameterType, Byte.MIN_VALUE, Byte.MAX_VALUE);
        }
        if (parameterType == BigInteger.class) {
            return BigDecimal.valueOf(getNonFractionalDoubleValue(argument, parameterName, parameterType))
                    .toBigInteger();
        }
        return argument;
    }

    /** Returns {@code argument} as a double, or throws if it is not a {@link Number}. */
    private static double getDoubleValue(Object argument, String parameterName, Class<?> parameterType) {
        if (argument instanceof Number) {
            return ((Number) argument).doubleValue();
        }
        throw new IllegalArgumentException(String.format(
                "Argument \"%s\" is not convertable to %s, got %s: <%s>",
                parameterName, parameterType.getName(), argument.getClass().getName(), argument));
    }

    /** Returns {@code argument} as a double, or throws if it is not a whole number. */
    private static double getNonFractionalDoubleValue(Object argument, String parameterName, Class<?> parameterType) {
        double value = getDoubleValue(argument, parameterName, parameterType);
        if (hasNoFractionalPart(value)) {
            return value;
        }
        throw new IllegalArgumentException(String.format(
                "Argument \"%s\" has non-integer value for %s: <%s>", parameterName, parameterType.getName(), argument));
    }

    /** Throws if {@code value} lies outside the inclusive range {@code [min, max]}. */
    private static void checkBounds(double value, String parameterName, Class<?> parameterType, double min, double max) {
        if (value < min || value > max) {
            throw new IllegalArgumentException(String.format(
                    "Argument \"%s\" is out of range for %s: <%s>", parameterName, parameterType.getName(), value));
        }
    }

    /** Returns {@code argument} as a long after checking it is whole and within {@code [min, max]}. */
    private static long getBoundedLongValue(
            Object argument, String parameterName, Class<?> parameterType, long min, long max) {
        double value = getNonFractionalDoubleValue(argument, parameterName, parameterType);
        checkBounds(value, parameterName, parameterType, min, max);
        return (long) value;
    }

    /** Returns true when {@code d} equals its own floor, i.e. has no fractional part. */
    static boolean hasNoFractionalPart(Double d) {
        return d.equals(Math.floor(d));
    }
}
