package io.quarkiverse.langchain4j.cohere.runtime;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.scoring.ScoringModel;
import io.quarkiverse.langchain4j.cohere.runtime.api.CohereApi;
import io.quarkiverse.langchain4j.cohere.runtime.api.RerankRequest;
import io.quarkiverse.langchain4j.cohere.runtime.api.RerankResponse;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import jakarta.ws.rs.core.MultivaluedHashMap;
import jakarta.ws.rs.core.MultivaluedMap;
import java.net.URI;
import java.net.URISyntaxException;
import java.time.Duration;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.eclipse.microprofile.rest.client.ext.ClientHeadersFactory;

/* loaded from: input_file:io/quarkiverse/langchain4j/cohere/runtime/QuarkusCohereScoringModel.class */
public class QuarkusCohereScoringModel implements ScoringModel {
    private final CohereApi cohereApi;
    private final String model;
    private final Integer maxRetries;

    public QuarkusCohereScoringModel(String str, final String str2, String str3, Duration duration, Integer num) {
        this.model = str3;
        this.maxRetries = num;
        try {
            this.cohereApi = (CohereApi) QuarkusRestClientBuilder.newBuilder().baseUri(new URI(str)).clientHeadersFactory(new ClientHeadersFactory() { // from class: io.quarkiverse.langchain4j.cohere.runtime.QuarkusCohereScoringModel.1
                public MultivaluedMap<String, String> update(MultivaluedMap<String, String> multivaluedMap, MultivaluedMap<String, String> multivaluedMap2) {
                    MultivaluedHashMap multivaluedHashMap = new MultivaluedHashMap();
                    multivaluedHashMap.put("Authorization", Collections.singletonList("Bearer " + str2));
                    return multivaluedHashMap;
                }
            }).connectTimeout(duration.toSeconds(), TimeUnit.SECONDS).readTimeout(duration.toSeconds(), TimeUnit.SECONDS).build(CohereApi.class);
        } catch (URISyntaxException e) {
            throw new RuntimeException(e);
        }
    }

    public Response<List<Double>> scoreAll(List<TextSegment> list, String str) {
        RerankRequest rerankRequest = new RerankRequest(this.model, str, (List) list.stream().map((v0) -> {
            return v0.text();
        }).collect(Collectors.toList()));
        RerankResponse rerankResponse = (RerankResponse) RetryUtils.withRetry(() -> {
            return this.cohereApi.rerank(rerankRequest);
        }, this.maxRetries.intValue());
        return Response.from((List) rerankResponse.getResults().stream().sorted(Comparator.comparingInt((v0) -> {
            return v0.getIndex();
        })).map((v0) -> {
            return v0.getRelevanceScore();
        }).collect(Collectors.toList()), new TokenUsage(rerankResponse.getMeta().getBilledUnits().getSearchUnits()));
    }
}
