package io.quarkiverse.langchain4j.deployment;

import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import io.quarkiverse.langchain4j.runtime.StructuredPromptsRecorder;
import io.quarkiverse.langchain4j.runtime.prompt.Mappable;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.ExecutionTime;
import io.quarkus.deployment.annotations.Record;
import io.quarkus.deployment.builditem.BytecodeTransformerBuildItem;
import io.quarkus.deployment.builditem.CombinedIndexBuildItem;
import io.quarkus.gizmo.ClassTransformer;
import io.quarkus.gizmo.MethodCreator;
import io.quarkus.gizmo.MethodDescriptor;
import io.quarkus.gizmo.ResultHandle;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.ClassInfo;
import org.jboss.jandex.DotName;
import org.jboss.jandex.FieldInfo;
import org.jboss.jandex.IndexView;
import org.jboss.logging.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.analysis.Analyzer;
import org.objectweb.asm.tree.analysis.AnalyzerException;
import org.objectweb.asm.tree.analysis.BasicValue;
import org.objectweb.asm.tree.analysis.SimpleVerifier;
import org.objectweb.asm.tree.analysis.Value;

/* loaded from: input_file:io/quarkiverse/langchain4j/deployment/PromptProcessor.class */
public class PromptProcessor {
    /** Build-time logger; categorised under this processor so its messages can be filtered on their own. */
    private static final Logger log = Logger.getLogger(PromptProcessor.class);

    /** Descriptor of {@code Object Map.put(Object, Object)}, used when generating bytecode. */
    public static final MethodDescriptor MAP_PUT = MethodDescriptor.ofMethod(Map.class, "put", Object.class, new Class[]{Object.class, Object.class});

    /** Descriptor of {@code void Map.putAll(Map)}, used when generating bytecode. */
    public static final MethodDescriptor MAP_PUT_ALL = MethodDescriptor.ofMethod(Map.class, "putAll", Void.TYPE, new Class[]{Map.class});

    /** Internal (slash-separated) name of {@link StructuredPromptProcessor}, as it appears in method instructions. */
    private static final String STRUCTURED_PROMPT_PROCESSOR_BINARY_NAME = StructuredPromptProcessor.class.getName().replace(".", "/");

    /** Name of the {@link StructuredPromptProcessor} method whose call sites are inspected. */
    private static final String TO_PROMPT = "toPrompt";

    /** JVM descriptor of that method: takes an {@code Object}, returns {@code dev.langchain4j.model.input.Prompt}. */
    private static final String TO_PROMPT_DESCRIPTOR = "(Ljava/lang/Object;)Ldev/langchain4j/model/input/Prompt;";

    /* loaded from: input_file:io/quarkiverse/langchain4j/deployment/PromptProcessor$StructuredPromptAnnotatedTransformer.class */
    /**
     * Bytecode transformer that makes a class implement {@link Mappable} by generating an
     * {@code obtainFieldValuesMap()} method returning the class's instance field values keyed by field name.
     */
    private static class StructuredPromptAnnotatedTransformer implements BiFunction<String, ClassVisitor, ClassVisitor> {
        private final ClassInfo annotatedClass;
        private final boolean hasSuperMappable;
        private final String superClassName;

        /**
         * @param targetClass class whose fields are exposed through the generated method
         * @param superIsMappable whether the generated method must also include the superclass's map
         * @param superName name of the superclass; only read when {@code superIsMappable} is true
         */
        private StructuredPromptAnnotatedTransformer(ClassInfo targetClass, boolean superIsMappable, String superName) {
            this.annotatedClass = targetClass;
            this.hasSuperMappable = superIsMappable;
            this.superClassName = superName;
        }

        /**
         * Wraps the given visitor so that the visited class gains the {@link Mappable} interface and the
         * generated {@code obtainFieldValuesMap()} method.
         *
         * @param className name supplied by the caller of the function; not used here
         * @param outputVisitor visitor the transformed class is written to
         * @return a visitor applying the transformation in front of {@code outputVisitor}
         */
        @Override // java.util.function.BiFunction
        public ClassVisitor apply(String className, ClassVisitor outputVisitor) {
            ClassTransformer transformer = new ClassTransformer(this.annotatedClass.name().toString());
            transformer.addInterface(Mappable.class);

            MethodCreator mapMethod = transformer.addMethod("obtainFieldValuesMap", Map.class, new Object[0]);
            ResultHandle ownValues = mapMethod.newInstance(MethodDescriptor.ofConstructor(HashMap.class, new Class[0]));

            // Put every instance, non-transient field declared by this class into the map under its own name.
            for (FieldInfo field : this.annotatedClass.fields()) {
                short modifiers = field.flags();
                if (Modifier.isStatic(modifiers) || Modifier.isTransient(modifiers)) {
                    continue;
                }
                ResultHandle key = mapMethod.load(field.name());
                ResultHandle value = mapMethod.readInstanceField(field, mapMethod.getThis());
                mapMethod.invokeInterfaceMethod(PromptProcessor.MAP_PUT, ownValues, key, value);
            }

            if (!this.hasSuperMappable) {
                mapMethod.returnValue(ownValues);
            } else {
                // Start from the superclass's map and copy this class's entries into it, so that entries
                // of this class replace same-named entries coming from the superclass.
                MethodDescriptor superMapMethod = MethodDescriptor.ofMethod(this.superClassName, "obtainFieldValuesMap", Map.class, new Object[0]);
                ResultHandle inheritedValues = mapMethod.invokeSpecialMethod(superMapMethod, mapMethod.getThis());
                mapMethod.invokeInterfaceMethod(PromptProcessor.MAP_PUT_ALL, inheritedValues, ownValues);
                mapMethod.returnValue(inheritedValues);
            }
            return transformer.applyTo(outputVisitor);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/quarkiverse/langchain4j/deployment/PromptProcessor$UnionValue.class */
    public static class UnionValue extends BasicValue {
        private final Set<Type> union;

        public static BasicValue create(BasicValue basicValue) {
            if (basicValue == null) {
                return null;
            }
            return basicValue.getType() == null ? new UnionValue(null, Set.of()) : new UnionValue(basicValue.getType(), Set.of(basicValue.getType()));
        }

        public static BasicValue create(BasicValue basicValue, BasicValue basicValue2, BasicValue basicValue3) {
            HashSet hashSet = new HashSet();
            hashSet.addAll(((UnionValue) basicValue2).union);
            hashSet.addAll(((UnionValue) basicValue3).union);
            return new UnionValue(basicValue.getType(), Set.copyOf(hashSet));
        }

        private UnionValue(Type type, Set<Type> set) {
            super(type);
            this.union = (Set) Objects.requireNonNull(set);
        }

        public String toString() {
            return super.toString() + " | union of " + this.union;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if ((obj instanceof UnionValue) && super.equals(obj)) {
                return Objects.equals(this.union, ((UnionValue) obj).union);
            }
            return false;
        }

        public int hashCode() {
            return Objects.hash(Integer.valueOf(super.hashCode()), this.union);
        }
    }

    /**
     * Build step that processes every class-level structured-prompt annotation found in the index:
     * it joins the annotation's lines into the prompt template, registers bytecode transformers for the
     * class and its indexed superclasses (unless the template has nested parameters), and records the
     * template under the class name. Finally it checks user code for unsafe processor usage.
     *
     * @param structuredPromptsRecorder recorder receiving each class name / template pair
     * @param combinedIndexBuildItem provides the Jandex index to scan
     * @param buildProducer sink for the bytecode transformers
     */
    @BuildStep
    @Record(ExecutionTime.STATIC_INIT)
    public void structuredPromptSupport(StructuredPromptsRecorder structuredPromptsRecorder, CombinedIndexBuildItem combinedIndexBuildItem, BuildProducer<BytecodeTransformerBuildItem> buildProducer) {
        IndexView index = combinedIndexBuildItem.getIndex();
        // Classes that already have a transformer registered. Each transformer adds obtainFieldValuesMap
        // to its class, so a class shared by several hierarchies (or an annotated superclass) must only
        // be registered once; otherwise the method would be added repeatedly.
        Set<DotName> transformedClasses = new HashSet<>();
        for (AnnotationInstance annotationInstance : index.getAnnotations(Langchain4jDotNames.STRUCTURED_PROMPT)) {
            AnnotationTarget target = annotationInstance.target();
            if (target.kind() != AnnotationTarget.Kind.CLASS) {
                continue;
            }
            String[] templateLines = annotationInstance.value().asStringArray();
            AnnotationValue delimiterValue = annotationInstance.value("delimiter");
            // Fall back to a newline when the annotation does not set an explicit delimiter.
            String delimiter = delimiterValue != null ? delimiterValue.asString() : "\n";
            String promptTemplate = String.join(delimiter, templateLines);
            ClassInfo annotatedClass = target.asClass();
            if (!hasNestedParams(promptTemplate)) {
                // Walk up the hierarchy. Stopping at an already registered class is safe because its own
                // ancestors were registered when it was first encountered.
                ClassInfo current = annotatedClass;
                while (current != null && transformedClasses.add(current.name())) {
                    DotName superName = current.superName();
                    ClassInfo superClass = DotNames.OBJECT.equals(superName) ? null : index.getClassByName(superName);
                    buildProducer.produce(new BytecodeTransformerBuildItem(current.name().toString(), new StructuredPromptAnnotatedTransformer(current, superClass != null, superName.toString())));
                    current = superClass;
                }
            }
            structuredPromptsRecorder.add(annotatedClass.name().toString(), promptTemplate);
        }
        warnForUnsafeUsage(index);
    }

    /**
     * Tells whether any element returned by {@code TemplateUtil.parts} for the given template holds more
     * than one entry (presumably a dotted, i.e. nested, parameter reference - confirm against TemplateUtil).
     *
     * @param str the prompt template text
     * @return {@code true} if at least one part has a size greater than one
     */
    private static boolean hasNestedParams(String str) {
        return TemplateUtil.parts(str)
                .stream()
                .anyMatch(segments -> segments.size() > 1);
    }

    /**
     * Inspects the bytecode of every indexed class that uses the structured prompt processor (excluding
     * the extension's and langchain4j's own packages) and logs a warning for each class that is passed to
     * the processor without carrying the structured-prompt annotation.
     *
     * @param indexView index used to find the users and to look up the candidate classes
     * @throws UncheckedIOException if the bytecode of a class cannot be read
     */
    private void warnForUnsafeUsage(IndexView indexView) {
        Set<String> candidates = new HashSet<>();
        for (ClassInfo user : indexView.getKnownUsers(Langchain4jDotNames.STRUCTURED_PROMPT_PROCESSOR)) {
            String className = user.name().toString();
            if (className.startsWith("io.quarkiverse.langchain4j") || className.startsWith("dev.langchain4j")) {
                continue;
            }
            String resourceName = className.replace('.', '/') + ".class";
            try (InputStream bytecode = Thread.currentThread().getContextClassLoader().getResourceAsStream(resourceName)) {
                if (bytecode == null) {
                    // Bytecode not reachable through the context class loader: skip only this class.
                    // Returning here would drop every remaining user and all warnings collected so far.
                    continue;
                }
                // 589824 == 0x90000, the numeric value of ASM's Opcodes.ASM9 API level.
                ClassNode classNode = new ClassNode(589824);
                new ClassReader(bytecode).accept(classNode, 0);
                for (MethodNode method : classNode.methods) {
                    analyze(classNode, method, candidates);
                }
            } catch (AnalyzerException e) {
                // Best effort: a class that cannot be analyzed simply yields no warnings.
                log.debug("Unable to analyze bytecode of class '" + className + "'", e);
            } catch (IOException e) {
                throw new UncheckedIOException("Reading bytecode of class '" + className + "' failed", e);
            }
        }
        for (String candidate : candidates) {
            ClassInfo candidateInfo = indexView.getClassByName(candidate);
            if (candidateInfo != null && !candidateInfo.hasDeclaredAnnotation(Langchain4jDotNames.STRUCTURED_PROMPT)) {
                log.warn("Class '" + candidate + "' is used in StructuredPromptProcessor but it is not annotated with @StructuredPrompt. This will likely result in an exception being thrown when the prompt is used.");
            }
        }
    }

    /**
     * Runs a data-flow analysis over one method and collects the class names of every value passed to
     * {@code StructuredPromptProcessor.toPrompt(Object)}.
     *
     * @param classNode class that declares the method
     * @param methodNode method to analyze
     * @param set receives the class names of the arguments found at matching call sites
     * @throws AnalyzerException if the analysis of the method fails
     */
    private void analyze(ClassNode classNode, MethodNode methodNode, final Set<String> set) throws AnalyzerException {
        Type currentClass = Type.getObjectType(classNode.name);
        Type superClass = Type.getObjectType(classNode.superName);
        List<Type> interfaces = classNode.interfaces.stream().map(Type::getObjectType).collect(Collectors.toList());
        // Modifier.INTERFACE shares the 0x200 bit with the class-file ACC_INTERFACE access flag.
        boolean isInterface = Modifier.isInterface(classNode.access);

        // 589824 == 0x90000, the numeric value of ASM's Opcodes.ASM9 API level.
        SimpleVerifier verifier = new SimpleVerifier(589824, currentClass, superClass, interfaces, isInterface) {
            @Override
            public BasicValue naryOperation(AbstractInsnNode abstractInsnNode, List<? extends BasicValue> list) throws AnalyzerException {
                if (abstractInsnNode.getType() == AbstractInsnNode.METHOD_INSN) {
                    MethodInsnNode call = (MethodInsnNode) abstractInsnNode;
                    if (PromptProcessor.STRUCTURED_PROMPT_PROCESSOR_BINARY_NAME.equals(call.owner) && PromptProcessor.TO_PROMPT.equals(call.name) && PromptProcessor.TO_PROMPT_DESCRIPTOR.equals(call.desc)) {
                        // The matched descriptor takes exactly one parameter, so it is the first value.
                        BasicValue argument = list.get(0);
                        if (argument instanceof UnionValue) {
                            set.addAll(((UnionValue) argument).union.stream().map(Type::getClassName).collect(Collectors.toSet()));
                        } else if (argument.getType() != null) {
                            set.add(argument.getType().getClassName());
                        }
                    }
                }
                return super.naryOperation(abstractInsnNode, list);
            }

            // Must keep the interpreter's method name to be an override; otherwise no value is ever
            // wrapped and the union of merged types is never tracked.
            @Override
            public BasicValue newValue(Type type) {
                return UnionValue.create(super.newValue(type));
            }

            @Override
            public BasicValue merge(BasicValue basicValue, BasicValue basicValue2) {
                return UnionValue.create(super.merge(basicValue, basicValue2), basicValue, basicValue2);
            }
        };
        new Analyzer<BasicValue>(verifier).analyze(classNode.name, methodNode);
    }
}
