package com.t4a.predict;

import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.generativeai.ChatSession;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import com.t4a.action.ExtendedPredictedAction;
import com.t4a.action.http.HttpPredictedAction;
import com.t4a.action.shell.ShellPredictedAction;
import com.t4a.api.AIAction;
import com.t4a.api.ActionType;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
import org.springframework.core.type.filter.AnnotationTypeFilter;

/* loaded from: input_file:com/t4a/predict/PredictionLoader.class */
/**
 * Registry of available {@link AIAction}s plus the LLM front-end that picks among them.
 *
 * <p>On first use ({@link #getInstance()}) the loader reads {@code tools4ai.properties} from the
 * classpath and opens two Gemini chat sessions: one for choosing actions and one for explanations.
 * When {@code openAiKey} is set, it also creates an OpenAI chat model. It then registers actions
 * from classpath scanning ({@link Predict} / {@link ActivateLoader}) and from the shell, HTTP and
 * Swagger loaders. Prompts embed the registered action names and ask the model to choose one or more.
 */
public class PredictionLoader {
    // Prompt fragments; the resulting prompt text is sent verbatim to the models.
    private static final String PREACTIONCMD = "here is my prompt - ";
    private static final String ACTIONCMD = "- what action do you think we should take ";
    /** Template whose placeholders prompt_str, action_name and params_values are substituted. */
    private static final String OPEN_AIPRMT = "here is your prompt - prompt_str - here is you action- action_name(params_values) - what parameter should you pass to this function. give comma separated name=values only and nothing else";
    private static final String POSTACTIONCMD = " - reply back with ";
    private static final String NUMACTION = " action only";
    private static final String NUMACTION_MULTIPROMPT = " actions only, in comma seperated list without any additional special characters";
    private static final String PROPERTIES_FILE = "tools4ai.properties";

    private static final Logger log = Logger.getLogger(PredictionLoader.class.getName());
    private static PredictionLoader predictionLoader = null;

    /** Registered actions keyed by action name. */
    private final Map<String, AIAction> predictions = new HashMap<>();
    /** Registered action names separated by commas; embedded verbatim in prompts. */
    private final StringBuffer actionNameList = new StringBuffer();
    private ChatSession chat;
    private ChatSession chatExplain;
    private String projectId;
    private String location;
    private String modelName;
    /** Null when no openAiKey is configured. */
    private ChatLanguageModel openAiChatModel;
    private String openAiKey;

    /**
     * Reads configuration, opens the Gemini chat sessions and, if a key is configured, creates the
     * OpenAI model.
     */
    private PredictionLoader() {
        initProp();
        // The VertexAI client is closed as soon as both chat sessions have been created.
        // NOTE(review): confirm the sessions remain usable after close() for the SDK version in use.
        try (VertexAI vertexAI = new VertexAI(this.projectId, this.location)) {
            GenerativeModel actionModel = GenerativeModel.newBuilder().setModelName(this.modelName).setVertexAi(vertexAI).build();
            GenerativeModel explainModel = GenerativeModel.newBuilder().setModelName(this.modelName).setVertexAi(vertexAI).build();
            this.chat = actionModel.startChat();
            this.chatExplain = explainModel.startChat();
        }
        if (this.openAiKey != null && !this.openAiKey.isEmpty()) {
            this.openAiChatModel = OpenAiChatModel.withApiKey(this.openAiKey);
        }
    }

    public String getProjectId() {
        return this.projectId;
    }

    public String getModelName() {
        return this.modelName;
    }

    public String getLocation() {
        return this.location;
    }

    /**
     * Loads {@code projectId}, {@code location}, {@code modelName} and the optional
     * {@code openAiKey} from {@code tools4ai.properties} on the classpath.
     *
     * @throws IllegalStateException if the file is missing or a required key is absent
     */
    public void initProp() {
        try (InputStream in = PredictionLoader.class.getClassLoader().getResourceAsStream(PROPERTIES_FILE)) {
            if (in == null) {
                throw new IllegalStateException(PROPERTIES_FILE + " not found on the classpath");
            }
            Properties properties = new Properties();
            properties.load(in);
            this.projectId = requiredProperty(properties, "projectId");
            this.location = requiredProperty(properties, "location");
            this.modelName = requiredProperty(properties, "modelName");
            // Optional: OpenAI features are only enabled when this key is present.
            this.openAiKey = optionalProperty(properties, "openAiKey");
            log.info("projectId: " + this.projectId);
            log.info("location: " + this.location);
            log.info("modelName: " + this.modelName);
        } catch (IOException e) {
            log.log(Level.SEVERE, "Could not read " + PROPERTIES_FILE, e);
        }
    }

    /** Returns the trimmed value of {@code key}, or null when the key is absent. */
    private static String optionalProperty(Properties properties, String key) {
        String value = properties.getProperty(key);
        return value == null ? null : value.trim();
    }

    /** Returns the trimmed value of {@code key}, failing with the key name when it is absent. */
    private static String requiredProperty(Properties properties, String key) {
        String value = optionalProperty(properties, key);
        if (value == null) {
            throw new IllegalStateException("Missing required property '" + key + "' in " + PROPERTIES_FILE);
        }
        return value;
    }

    /**
     * Asks Gemini for {@code i} actions matching the prompt.
     *
     * <p>The reply is split on commas, or on newlines if it contains no comma. Each trimmed entry
     * is looked up; entries that match no registered action yield {@code null} elements.
     *
     * @param str the user prompt
     * @param i number of actions requested
     * @return the looked-up actions, in reply order (may contain nulls)
     */
    public List<AIAction> getPredictedAction(String str, int i) {
        try {
            String text = ResponseHandler.getText(this.chat.sendMessage(buildPrompt(str, i)));
            String[] names = text.split(",");
            if (names.length <= 1) {
                names = text.split("\n");
            }
            List<AIAction> actions = new ArrayList<>(names.length);
            for (String name : names) {
                actions.add(getAiAction(name.trim()));
            }
            return actions;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Joins the map's keys with {@code ", "} in iteration order; empty string for an empty map. */
    public String getCommaSeparatedKeys(Map<String, ?> map) {
        return String.join(", ", map.keySet());
    }

    /**
     * Asks the OpenAI model which parameter values to pass to {@code aIAction} for prompt
     * {@code str}.
     *
     * @param map parameter names (keys) listed in the prompt
     * @return the raw model reply
     * @throws IllegalStateException if no OpenAI key is configured
     */
    public String getActionParams(AIAction aIAction, String str, AIPlatform aIPlatform, Map<String, Type> map) {
        String prompt = OPEN_AIPRMT.replace("prompt_str", str)
                .replace("action_name", aIAction.getActionName())
                .replace("params_values", getCommaSeparatedKeys(map));
        log.info(prompt);
        return requireOpenAiModel().generate(prompt);
    }

    /**
     * Sends {@code str + " result " + str2} to the OpenAI model and returns its reply.
     *
     * @throws IllegalStateException if no OpenAI key is configured
     */
    public String postActionProcessing(AIAction aIAction, String str, AIPlatform aIPlatform, Map<String, Type> map, String str2) {
        return requireOpenAiModel().generate(str + " result " + str2);
    }

    /**
     * Asks the given platform for a single action matching the prompt.
     *
     * <p>For OpenAI, {@code "()"} is stripped from the reply. If no exact match is found, a
     * case-insensitive match against the registered names is tried.
     *
     * @return the matching action, or {@code null} if none matched or the platform is unsupported
     * @throws IllegalStateException if OpenAI is requested but no OpenAI key is configured
     */
    public AIAction getPredictedAction(String str, AIPlatform aIPlatform) {
        try {
            if (AIPlatform.GEMINI == aIPlatform) {
                String reply = ResponseHandler.getText(this.chat.sendMessage(buildPrompt(str, 1)));
                // Trim so trailing whitespace/newlines in the reply don't defeat the lookup.
                return getAiAction(reply.trim());
            }
            if (AIPlatform.OPENAI == aIPlatform) {
                String reply = requireOpenAiModel().generate(buildPromptForOpenAI(str, 1)).replace("()", "").trim();
                AIAction action = getAiAction(reply);
                if (action == null) {
                    log.info("action not found , trying again");
                    String fetchActionNameFromList = fetchActionNameFromList(reply);
                    log.info("action name " + fetchActionNameFromList);
                    action = getAiAction(fetchActionNameFromList);
                }
                return action;
            }
            return null;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the registered action name equal to {@code str} ignoring case, or {@code null}.
     */
    public String fetchActionNameFromList(String str) {
        for (String candidate : this.actionNameList.toString().split(",")) {
            if (candidate.equalsIgnoreCase(str)) {
                return candidate;
            }
        }
        return null;
    }

    /** Single-action prediction using Gemini. */
    public AIAction getPredictedAction(String str) {
        return getPredictedAction(str, AIPlatform.GEMINI);
    }

    /**
     * Asks Gemini to split the prompt into sub-prompts, each paired with an action name.
     *
     * @return the raw model reply
     */
    public String getPredictedActionMultiStep(String str) {
        try {
            String text = ResponseHandler.getText(this.chat.sendMessage(buildPromptMultiStep(str)));
            log.info(text);
            return text;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Asks the explanation chat session why action {@code str2} fits command {@code str}.
     */
    public String explainAction(String str, String str2) {
        try {
            return ResponseHandler.getText(this.chatExplain.sendMessage("explain why this action " + str2 + " is appropriate for this command " + str + " out of all these actions " + this.actionNameList));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Looks up a registered action by exact name; returns {@code null} when none is registered. */
    private AIAction getAiAction(String str) {
        log.info(" Trying to load " + str);
        return this.predictions.get(str);
    }

    /**
     * Returns the shared loader, creating and populating it on first call.
     *
     * <p>Synchronized so concurrent first calls do not both scan the classpath. The instance is
     * published before the action loaders run, so a re-entrant call from a loader on the same
     * thread receives the (partially populated) instance rather than recursing.
     */
    public static synchronized PredictionLoader getInstance() {
        if (predictionLoader == null) {
            predictionLoader = new PredictionLoader();
            predictionLoader.processCP();
            predictionLoader.loadShellCommands();
            predictionLoader.loadHttpCommands();
            predictionLoader.loadSwaggerHttpActions();
        }
        return predictionLoader;
    }

    private void loadShellCommands() {
        try {
            new ShellPredictionLoader().load(this.predictions, this.actionNameList);
        } catch (LoaderException e) {
            log.warning(e.getMessage());
        }
    }

    private void loadSwaggerHttpActions() {
        try {
            new SwaggerPredictionLoader().load(this.predictions, this.actionNameList);
        } catch (LoaderException e) {
            log.warning(e.getMessage());
        }
    }

    private void loadHttpCommands() {
        try {
            new HttpRestPredictionLoader().load(this.predictions, this.actionNameList);
        } catch (LoaderException e) {
            log.warning(e.getMessage());
        }
    }

    /**
     * Scans the classpath for classes annotated with {@link Predict} or {@link ActivateLoader}.
     *
     * <p>Each match is registered as an action, or used as a loader of extended actions.
     *
     * @throws RuntimeException if a candidate class cannot be loaded or instantiated
     */
    public void processCP() {
        ClassPathScanningCandidateComponentProvider scanner = new ClassPathScanningCandidateComponentProvider(true);
        scanner.addIncludeFilter(new AnnotationTypeFilter(Predict.class));
        scanner.addIncludeFilter(new AnnotationTypeFilter(ActivateLoader.class));
        scanner.findCandidateComponents("*").forEach(beanDefinition -> registerCandidate(beanDefinition.getBeanClassName()));
    }

    /** Registers one scanned class as an action or invokes it as an extended-action loader. */
    private void registerCandidate(String className) {
        try {
            Class<?> cls = Class.forName(className);
            if (AIAction.class.isAssignableFrom(cls)) {
                log.info("Class " + cls + " implements AIAction");
                if (ExtendedPredictedAction.class.isAssignableFrom(cls)) {
                    log.warning("You cannot predict extended option implement AIAction instead" + cls);
                } else {
                    addAction(cls);
                }
            } else if (ExtendedPredictionLoader.class.isAssignableFrom(cls)) {
                log.info("Class " + cls + " implements Loader");
                loadFromLoader(cls);
            }
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Failed to register class " + className, e);
        }
    }

    /**
     * Instantiates an {@link ExtendedPredictionLoader} via its no-arg constructor and registers
     * every action it provides.
     *
     * <p>A {@link LoaderException} is logged and skipped.
     */
    private void loadFromLoader(Class<?> cls) throws ReflectiveOperationException {
        try {
            ExtendedPredictionLoader loader = (ExtendedPredictionLoader) cls.getDeclaredConstructor().newInstance();
            Map<String, ExtendedPredictedAction> extendedActions = loader.getExtendedActions();
            for (Map.Entry<String, ExtendedPredictedAction> entry : extendedActions.entrySet()) {
                log.info(" names " + this.actionNameList);
                this.actionNameList.append(entry.getKey()).append(",");
                this.predictions.put(entry.getKey(), entry.getValue());
            }
        } catch (LoaderException e) {
            log.severe(e.getMessage() + " for " + cls.getName());
        }
    }

    /** Instantiates an {@link AIAction} via its no-arg constructor and registers it by name. */
    private void addAction(Class<?> cls) throws ReflectiveOperationException {
        AIAction action = (AIAction) cls.getDeclaredConstructor().newInstance();
        String actionName = action.getActionName();
        this.actionNameList.append(actionName).append(",");
        this.predictions.put(actionName, action);
    }

    /** Returns the live (mutable) registry of actions keyed by name. */
    public Map<String, AIAction> getPredictions() {
        return this.predictions;
    }

    /** Returns the live (mutable) buffer of comma-separated action names. */
    public StringBuffer getActionNameList() {
        return this.actionNameList;
    }

    /** Builds the action-selection prompt listing the raw action names. */
    private String buildPrompt(String str, int i) {
        String prompt = selectionPrompt(str, this.actionNameList.toString(), i);
        log.info(prompt);
        return prompt;
    }

    /**
     * Formats action names as function calls, e.g. {@code "a,b,"} becomes {@code "a(),b()"}.
     *
     * <p>Empty entries are skipped, so an empty buffer yields {@code ""}.
     */
    private String getModifiedActionName(StringBuffer stringBuffer) {
        StringBuilder sb = new StringBuilder();
        for (String name : stringBuffer.toString().split(",")) {
            if (name.isEmpty()) {
                continue;
            }
            if (sb.length() > 0) {
                sb.append(',');
            }
            sb.append(name).append("()");
        }
        return sb.toString();
    }

    /** Builds the action-selection prompt listing action names as function calls for OpenAI. */
    private String buildPromptForOpenAI(String str, int i) {
        String prompt = selectionPrompt(str, getModifiedActionName(this.actionNameList), i);
        log.info(prompt);
        return prompt;
    }

    /** Shared prompt shape used by both platforms to request {@code count} action(s). */
    private static String selectionPrompt(String userPrompt, String actions, int count) {
        return PREACTIONCMD + userPrompt + ACTIONCMD + actions + POSTACTIONCMD + count
                + (count > 1 ? NUMACTION_MULTIPROMPT : NUMACTION);
    }

    private String buildPromptMultiStep(String str) {
        return "break down this prompt into multiple prompts and associated action in comma separated list , this is your prompt - " + str + " - action list is here -" + this.actionNameList + " you will provide the result in this format - sub-prompt,action. If not action matches the sub-prompt please put blankAction";
    }
}
