package chat.octet.accordion.action.llama;

import chat.octet.accordion.action.AbstractAction;
import chat.octet.accordion.action.model.ActionConfig;
import chat.octet.accordion.action.model.ExecuteResult;
import chat.octet.accordion.exceptions.ActionException;
import chat.octet.accordion.utils.JsonUtils;
import chat.octet.model.Model;
import chat.octet.model.beans.CompletionResult;
import chat.octet.model.parameters.GenerateParameter;
import com.google.common.base.Preconditions;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringSubstitutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/accordion/action/llama/LlamaAction.class */
/**
 * Action that runs text generation on a local Llama model.
 *
 * <p>The input text is read from the {@code LLAMA_INPUT} input parameter. The generated content
 * is stored under the {@code LLAMA_OUTPUT} key of the returned {@link ExecuteResult}. When the
 * action parameters enable chat mode, the model is invoked through chat completions with an
 * optional system prompt. Otherwise a plain completion is requested.
 *
 * <p>The {@link Model} instance is created once in the constructor and closed in {@link #close()}.
 */
public class LlamaAction extends AbstractAction {
    private static final Logger log = LoggerFactory.getLogger(LlamaAction.class);
    /** Input parameter key holding the prompt text. */
    public static final String LLAMA_INPUT = "LLAMA_INPUT";
    /** Execute-result key under which the generated content is stored. */
    public static final String LLAMA_OUTPUT = "LLAMA_OUTPUT";
    private final LlamaParameter params;
    private final GenerateParameter generateParams;
    private final Model model;

    /**
     * Creates the action and loads the configured model.
     *
     * @param actionConfig action configuration carrying a {@link LlamaParameter}
     * @throws IllegalArgumentException if the Llama parameter has no model parameter
     */
    public LlamaAction(ActionConfig actionConfig) {
        super(actionConfig);
        // NOTE(review): presumably getActionParams rejects a missing parameter using the given message.
        this.params = (LlamaParameter) actionConfig.getActionParams(LlamaParameter.class, "Llama parameter cannot be null.");
        Preconditions.checkArgument(this.params.getModelParameter() != null, "Llama model parameter cannot be null.");
        // orElseGet: only build the default GenerateParameter when none was configured.
        this.generateParams = (GenerateParameter) Optional.ofNullable(this.params.getGenerateParameter())
                .orElseGet(() -> GenerateParameter.builder().build());
        this.model = new Model(this.params.getModelParameter());
    }

    /**
     * Runs inference on the {@code LLAMA_INPUT} text.
     *
     * <p>If the input is empty, no inference is performed and an empty result is returned. Any
     * exception raised during inference is wrapped in an {@link ActionException} and recorded via
     * {@code setExecuteThrowable} instead of being thrown.
     *
     * @return the execute result, holding the generated content under {@code LLAMA_OUTPUT} on success
     * @throws ActionException declared by the {@code ActionService} contract
     */
    @Override // chat.octet.accordion.action.ActionService
    public ExecuteResult execute() throws ActionException {
        ExecuteResult executeResult = new ExecuteResult();
        try {
            String input = getInputParameter().getString(LLAMA_INPUT);
            if (StringUtils.isEmpty(input)) {
                log.warn("Llama input text is empty, skipping inference.");
                return executeResult;
            }
            CompletionResult completions;
            if (this.params.isChatMode()) {
                if (!this.params.isMemory()) {
                    // Without memory, clear this user's chat status so each call starts fresh.
                    this.model.removeChatStatus(this.generateParams.getUser());
                }
                completions = this.model.chatCompletions(this.generateParams, resolveSystemPrompt(), input);
            } else {
                completions = this.model.completions(this.generateParams, input);
            }
            executeResult.add(LLAMA_OUTPUT, completions.getContent());
            // Serializing the full result is costly; only do it when DEBUG output is actually enabled.
            if (log.isDebugEnabled()) {
                log.debug("Llama action execution finished, result: {}", JsonUtils.toJson(completions));
            }
        } catch (Exception e) {
            setExecuteThrowable(new ActionException(e.getMessage(), e));
        }
        return executeResult;
    }

    /**
     * Resolves the configured system prompt, substituting placeholders from the input parameters.
     *
     * @return the resolved system prompt, or {@code null} when no system prompt is configured
     */
    private String resolveSystemPrompt() {
        String system = this.params.getSystem();
        if (StringUtils.isEmpty(system)) {
            return null;
        }
        return StringSubstitutor.replace(system, getInputParameter());
    }

    /** Closes the underlying model. */
    @Override // chat.octet.accordion.action.AbstractAction, chat.octet.accordion.action.ActionService
    public void close() {
        if (this.model != null) {
            this.model.close();
            log.debug("Close llama model and release all resources.");
        }
    }
}
