package com.huaweicloud.pangu.dev.sdk.skill.document;

import com.huaweicloud.pangu.dev.sdk.api.llms.LLM;
import com.huaweicloud.pangu.dev.sdk.api.memory.bo.Document;
import com.huaweicloud.pangu.dev.sdk.exception.PanguDevSDKException;
import com.huaweicloud.pangu.dev.sdk.skill.AbstractDocSkill;
import com.huaweicloud.pangu.dev.sdk.template.PromptTemplate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/huaweicloud/pangu/dev/sdk/skill/document/DocMapReduceSkill.class */
/**
 * Document skill that answers a question over many documents with a map-reduce strategy.
 *
 * <p>Map phase: every document is sent to the LLM together with the question, using the map prompt.
 * Reduce phase: the per-document answers are combined with the reduce prompt. While the formatted
 * reduce prompt is longer than the configured limit, the answers are split into groups, each group is
 * reduced to a single answer, and the process repeats until one final reduce prompt fits.
 *
 * <p>{@link #setMapPrompt(String)} and {@link #setReducePrompt(String)} must both be called before
 * {@link #execute(List, String, int)} or {@link #executeWithDocs(List, String)}.
 */
public class DocMapReduceSkill extends AbstractDocSkill {
    private static final Logger log = LoggerFactory.getLogger(DocMapReduceSkill.class);

    /** Limit passed to {@link #execute(List, String, int)} by {@link #executeWithDocs(List, String)}. */
    private static final int DEFAULT_MAX_REDUCE_PROMPT_LENGTH = 2000;

    private final LLM llm;
    private PromptTemplate mapPrompt;
    private PromptTemplate reducePrompt;

    /**
     * Creates the skill.
     *
     * @param llm model used for both the map and the reduce calls
     */
    public DocMapReduceSkill(LLM llm) {
        this.llm = llm;
    }

    /**
     * Sets the template applied to each individual document in the map phase.
     *
     * @param str template text; it is formatted with the document content and the question
     */
    public void setMapPrompt(String str) {
        this.mapPrompt = new PromptTemplate(str);
    }

    /**
     * Sets the template used to combine intermediate answers in the reduce phase.
     *
     * @param str template text; it is formatted with the list of summaries and the question
     */
    public void setReducePrompt(String str) {
        this.reducePrompt = new PromptTemplate(str);
    }

    @Override
    public String executeWithDocs(List<Document> list, String str) {
        return execute(list, str, DEFAULT_MAX_REDUCE_PROMPT_LENGTH);
    }

    /**
     * Answers a question over the given documents using map-reduce.
     *
     * @param list documents to answer the question over
     * @param str the question
     * @param i maximum allowed length of a formatted reduce prompt, as measured by
     *     {@link #getReducePromptToken(List, String)}
     * @return the final reduced answer
     * @throws PanguDevSDKException if a prompt has not been set, or if the answers cannot be grouped
     *     so that every reduce prompt stays within {@code i}
     */
    public String execute(List<Document> list, String str, int i) {
        checkPromptsConfigured();

        List<Document> answers = new ArrayList<>(list.size());
        for (Document document : list) {
            answers.add(mapDocument(document, str));
        }

        // Collapse intermediate answers group by group until a single reduce prompt fits the limit.
        while (getReducePromptToken(answers, str) > i) {
            List<Document> reduced = new ArrayList<>();
            for (List<Document> group : splitDocs(answers, str, i)) {
                reduced.add(Document.builder().pageContent(reduce(group, str)).build());
            }
            answers = reduced;
        }
        return reduce(answers, str);
    }

    /** Fails fast with a descriptive error instead of an NPE when a prompt template is missing. */
    private void checkPromptsConfigured() {
        if (this.mapPrompt == null) {
            throw new PanguDevSDKException("map prompt is not set, call setMapPrompt first");
        }
        if (this.reducePrompt == null) {
            throw new PanguDevSDKException("reduce prompt is not set, call setReducePrompt first");
        }
    }

    /** Runs the map prompt for one document and wraps the LLM answer in a new document. */
    private Document mapDocument(Document document, String question) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put(AbstractDocSkill.PromptParam.DOCUMENT, document.getPageContent());
        inputs.put(AbstractDocSkill.PromptParam.QUESTION, question);
        String answer = this.llm.ask(this.mapPrompt.format(inputs)).getAnswer();
        log.debug("policy: mapreduce, map answer is: {}", answer);
        return Document.builder().pageContent(answer).build();
    }

    /** Combines the given documents into a single answer with one reduce-prompt LLM call. */
    private String reduce(List<Document> docs, String question) {
        String answer = this.llm.ask(this.reducePrompt.format(getReduceInputs(docs, question))).getAnswer();
        log.debug("policy: mapreduce, reduce answer is: {}", answer);
        return answer;
    }

    /**
     * Greedily partitions {@code docs}, in order, into consecutive groups whose formatted reduce
     * prompt stays below {@code maxLength}.
     *
     * <p>A group is closed as soon as adding the next document would reach the limit. A closed group
     * must contain at least two documents. Together with the throw for size {@code <= 2}, this ensures
     * every reduce round strictly shrinks the number of documents, so the loop in
     * {@code execute} terminates.
     *
     * @throws PanguDevSDKException if a document cannot be placed in a group within the limit
     */
    private List<List<Document>> splitDocs(List<Document> docs, String question, int maxLength) {
        List<List<Document>> groups = new ArrayList<>();
        List<Document> current = new ArrayList<>();
        for (Document document : docs) {
            current.add(document);
            if (getReducePromptToken(current, question) >= maxLength) {
                if (current.size() <= 2) {
                    throw new PanguDevSDKException("single document is too long");
                }
                groups.add(new ArrayList<>(current.subList(0, current.size() - 1)));
                current.clear();
                current.add(document);
                // The document that starts a new group may exceed the limit on its own. Reject it here
                // rather than letting it end up as an oversized trailing group sent to the LLM.
                if (getReducePromptToken(current, question) >= maxLength) {
                    throw new PanguDevSDKException("single document is too long");
                }
            }
        }
        if (!current.isEmpty()) {
            groups.add(current);
        }
        return groups;
    }

    /**
     * Returns the size of the reduce prompt formatted for {@code docs}.
     *
     * <p>Despite the name, this is the {@code length()} of the formatted prompt, which is a character
     * count (assuming {@code format} returns a String), not a model token count.
     */
    private int getReducePromptToken(List<Document> docs, String question) {
        return this.reducePrompt.format(getReduceInputs(docs, question)).length();
    }

    /** Builds the reduce-prompt variables: the documents as summaries plus the question. */
    private Map<String, Object> getReduceInputs(List<Document> docs, String question) {
        Map<String, Object> inputs = new HashMap<>();
        inputs.put(AbstractDocSkill.PromptParam.SUMMARIES, docs);
        inputs.put(AbstractDocSkill.PromptParam.QUESTION, question);
        return inputs;
    }
}
