package dev.langchain4j.model.huggingface;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.huggingface.client.EmbeddingRequest;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
import dev.langchain4j.model.huggingface.spi.HuggingFaceEmbeddingModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.spi.ServiceHelper;
import java.time.Duration;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/model/huggingface/HuggingFaceEmbeddingModel.class */
/**
 * Embedding model that delegates to a {@link HuggingFaceClient} to turn text into {@link Embedding}s.
 *
 * <p>The client is obtained from {@code FactoryCreator.FACTORY} using the access token, model id and
 * timeout supplied at construction time. When no model id is supplied,
 * {@code HuggingFaceModelName.SENTENCE_TRANSFORMERS_ALL_MINI_LM_L6_V2} is used; when no timeout is
 * supplied, a default of 15 seconds is used.
 *
 * <p>Instances are usually created through {@link #builder()} or {@link #withAccessToken(String)}.
 */
public class HuggingFaceEmbeddingModel extends DimensionAwareEmbeddingModel {
    /** Timeout handed to the client factory when the caller does not specify one. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(15);

    private final HuggingFaceClient client;
    private final boolean waitForModel;
    /** Effective model id, i.e. the caller's value or the default that was passed to the client factory. */
    private final String modelId;

    /**
     * Fluent builder for {@link HuggingFaceEmbeddingModel}.
     *
     * <p>Every setting is optional here; validation and defaulting happen in the model constructor.
     */
    public static class HuggingFaceEmbeddingModelBuilder {
        private String accessToken;
        private String modelId;
        private Boolean waitForModel;
        private Duration timeout;

        /**
         * @param accessToken the HuggingFace access token; must be non-blank by the time
         *                    {@link #build()} is called
         * @return this builder
         */
        public HuggingFaceEmbeddingModelBuilder accessToken(String accessToken) {
            this.accessToken = accessToken;
            return this;
        }

        /**
         * @param modelId the model id, or {@code null} to use the default model
         * @return this builder
         */
        public HuggingFaceEmbeddingModelBuilder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        /**
         * @param waitForModel flag forwarded with every embedding request; {@code null} is treated
         *                     as {@code true}
         * @return this builder
         */
        public HuggingFaceEmbeddingModelBuilder waitForModel(Boolean waitForModel) {
            this.waitForModel = waitForModel;
            return this;
        }

        /**
         * @param timeout the timeout handed to the client factory, or {@code null} to use the default
         * @return this builder
         */
        public HuggingFaceEmbeddingModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Creates the model from the values collected so far.
         *
         * @return a new {@link HuggingFaceEmbeddingModel}
         * @throws IllegalArgumentException if the access token is {@code null} or blank
         */
        public HuggingFaceEmbeddingModel build() {
            return new HuggingFaceEmbeddingModel(this.accessToken, this.modelId, this.waitForModel, this.timeout);
        }

        /**
         * Returns a description of the builder state. The access token is masked so that the
         * secret does not end up in logs or debugger output.
         */
        @Override
        public String toString() {
            String maskedToken = this.accessToken == null ? "null" : "********";
            return "HuggingFaceEmbeddingModel.HuggingFaceEmbeddingModelBuilder(accessToken=" + maskedToken
                    + ", modelId=" + this.modelId
                    + ", waitForModel=" + this.waitForModel
                    + ", timeout=" + this.timeout + ")";
        }
    }

    /**
     * Creates a model backed by a client from {@code FactoryCreator.FACTORY}.
     *
     * @param accessToken  the HuggingFace access token; must not be {@code null} or blank
     * @param modelId      the model id, or {@code null} to use
     *                     {@code HuggingFaceModelName.SENTENCE_TRANSFORMERS_ALL_MINI_LM_L6_V2}
     * @param waitForModel flag forwarded with every embedding request; {@code null} is treated as
     *                     {@code true}
     * @param timeout      the timeout handed to the client factory, or {@code null} for the
     *                     15-second default
     * @throws IllegalArgumentException if {@code accessToken} is {@code null} or blank
     */
    public HuggingFaceEmbeddingModel(String accessToken, String modelId, Boolean waitForModel, Duration timeout) {
        if (accessToken == null || accessToken.trim().isEmpty()) {
            throw new IllegalArgumentException("HuggingFace access token must be defined. It can be generated here: https://huggingface.co/settings/tokens");
        }
        // Resolve defaults once so the client and this instance agree on the effective values.
        final String effectiveModelId =
                modelId == null ? HuggingFaceModelName.SENTENCE_TRANSFORMERS_ALL_MINI_LM_L6_V2 : modelId;
        final Duration effectiveTimeout = timeout == null ? DEFAULT_TIMEOUT : timeout;

        this.client = FactoryCreator.FACTORY.create(new HuggingFaceClientFactory.Input() {
            @Override
            public String apiKey() {
                return accessToken;
            }

            @Override
            public String modelId() {
                return effectiveModelId;
            }

            @Override
            public Duration timeout() {
                return effectiveTimeout;
            }
        });
        this.waitForModel = waitForModel == null || waitForModel;
        this.modelId = effectiveModelId;
    }

    /**
     * Embeds the text of every given segment.
     *
     * @param textSegments the segments whose text is sent to the client
     * @return a response wrapping one embedding per element returned by the client
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<String> texts = textSegments.stream()
                .map(TextSegment::text)
                .collect(Collectors.toList());
        return embedTexts(texts);
    }

    /** Sends the texts to the client in a single request and converts each result into an {@link Embedding}. */
    private Response<List<Embedding>> embedTexts(List<String> texts) {
        EmbeddingRequest request = new EmbeddingRequest(texts, this.waitForModel);
        List<Embedding> embeddings = this.client.embed(request).stream()
                .map(Embedding::from)
                .collect(Collectors.toList());
        return Response.from(embeddings);
    }

    /**
     * Creates a model with the given access token and default values for everything else.
     *
     * @param accessToken the HuggingFace access token
     * @return a new model
     * @throws IllegalArgumentException if {@code accessToken} is {@code null} or blank
     */
    public static HuggingFaceEmbeddingModel withAccessToken(String accessToken) {
        return builder().accessToken(accessToken).build();
    }

    /**
     * Returns a builder. If {@link ServiceHelper#loadFactories} yields at least one
     * {@link HuggingFaceEmbeddingModelBuilderFactory}, the first one supplies the builder;
     * otherwise a plain {@link HuggingFaceEmbeddingModelBuilder} is created.
     *
     * @return a builder for {@link HuggingFaceEmbeddingModel}
     */
    public static HuggingFaceEmbeddingModelBuilder builder() {
        // Only the first registered factory is consulted.
        for (HuggingFaceEmbeddingModelBuilderFactory factory
                : ServiceHelper.loadFactories(HuggingFaceEmbeddingModelBuilderFactory.class)) {
            return factory.get();
        }
        return new HuggingFaceEmbeddingModelBuilder();
    }
}
