package dev.ai4j.openai4j;

import com.google.gson.FieldNamingPolicy;
import com.google.gson.GsonBuilder;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.ai4j.openai4j.embedding.EmbeddingRequest;
import dev.ai4j.openai4j.embedding.EmbeddingResponse;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.function.Function;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

/* loaded from: input_file:dev/ai4j/openai4j/OpenAiService.class */
/**
 * Client for the OpenAI completion, chat-completion and embedding endpoints.
 *
 * <p>Synchronous methods block and throw {@link RuntimeException} when the HTTP call fails or the
 * server answers with a non-2xx status (the exception message is the error body). Async methods
 * report the result or failure to a {@link ResponseHandler}; streaming methods report incremental
 * text to a {@link StreamingResponseHandler}.
 */
public class OpenAiService {
    private static final Logger log = LoggerFactory.getLogger(OpenAiService.class);
    /** Content type of the JSON body sent on streaming requests (parsed once, reused). */
    private static final MediaType JSON_MEDIA_TYPE = MediaType.get("application/json; charset=utf-8");

    private final String url;
    private final OkHttpClient okHttpClient;
    private final OpenAiApi openAiApi;

    /**
     * Fluent builder for {@link OpenAiService}.
     *
     * <p>Defaults: base URL {@code https://api.openai.com/} and a 60 second call timeout. No API key
     * is set by default.
     */
    public static class Builder {
        private String url;
        private String apiKey;
        private Duration timeout;

        private Builder() {
            this.url = "https://api.openai.com/";
            this.timeout = Duration.ofSeconds(60L);
        }

        /**
         * Sets the base URL of the API. A trailing {@code /} is appended when missing, so relative
         * endpoint paths resolve beneath it.
         *
         * @throws IllegalArgumentException if {@code url} is null or blank
         */
        public Builder url(String url) {
            if (url == null || url.trim().isEmpty()) {
                throw new IllegalArgumentException("URL cannot be null or empty");
            }
            this.url = url.endsWith("/") ? url : url + "/";
            return this;
        }

        /**
         * Sets the API key passed to the key-inserting interceptor.
         *
         * @throws IllegalArgumentException if {@code apiKey} is null or blank
         */
        public Builder apiKey(String apiKey) {
            if (apiKey == null || apiKey.trim().isEmpty()) {
                throw new IllegalArgumentException("API key cannot be null or empty. API keys can be generated here: https://platform.openai.com/account/api-keys");
            }
            this.apiKey = apiKey;
            return this;
        }

        /**
         * Sets the OkHttp call timeout used for every request.
         *
         * @throws IllegalArgumentException if {@code timeout} is null
         */
        public Builder timeout(Duration timeout) {
            if (timeout == null) {
                throw new IllegalArgumentException("Timeout cannot be null");
            }
            this.timeout = timeout;
            return this;
        }

        /** Creates a new {@link OpenAiService} from the current builder settings. */
        public OpenAiService build() {
            return new OpenAiService(this);
        }
    }

    /** Creates a service for the default base URL and timeout using the given API key. */
    public OpenAiService(String apiKey) {
        this(builder().apiKey(apiKey));
    }

    private OpenAiService(Builder builder) {
        this.url = builder.url;
        this.okHttpClient = new OkHttpClient.Builder()
                .addInterceptor(new ApiKeyInsertingInterceptor(builder.apiKey))
                .addInterceptor(new RequestLoggingInterceptor())
                .addInterceptor(new ResponseLoggingInterceptor())
                .callTimeout(builder.timeout)
                .build();
        // The API uses snake_case JSON field names; Gson maps them onto camelCase Java fields.
        this.openAiApi = new Retrofit.Builder()
                .baseUrl(builder.url)
                .client(this.okHttpClient)
                .addConverterFactory(GsonConverterFactory.create(new GsonBuilder()
                        .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
                        .create()))
                .build()
                .create(OpenAiApi.class);
    }

    /** Returns a new {@link Builder} initialised with the default URL and timeout. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Executes a (non-streaming) completion request synchronously.
     *
     * @throws IllegalArgumentException if the request has {@code stream == true}
     * @throws RuntimeException on transport failure or a non-2xx response
     */
    public CompletionResponse getCompletions(CompletionRequest completionRequest) {
        if (Boolean.TRUE.equals(completionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'true' is not compatible with getCompletion(...) method. If you need streaming, use one of streamCompletion(...) methods. If you do NOT need streaming, do not set 'stream' parameter in CompletionRequest, or set it to 'false'.");
        }
        return (CompletionResponse) execute(this.openAiApi.completion(completionRequest));
    }

    /** Convenience: completes {@code prompt} synchronously and returns the response text. */
    @Experimental
    public String getCompletion(String prompt) {
        return getCompletions(CompletionRequest.builder().prompt(prompt).build()).text();
    }

    /**
     * Executes a (non-streaming) completion request asynchronously. Non-2xx responses are reported
     * via {@link ResponseHandler#onFailure(Throwable)}.
     *
     * @throws IllegalArgumentException if the request has {@code stream == true}
     */
    @Experimental
    public void getCompletionsAsync(CompletionRequest completionRequest, final ResponseHandler<CompletionResponse> responseHandler) {
        if (Boolean.TRUE.equals(completionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'true' is not compatible with getCompletionAsync(...) method. If you need streaming, use one of streamCompletion(...) methods. If you do NOT need streaming, do not set 'stream' parameter in CompletionRequest, or set it to 'false'.");
        }
        enqueue(this.openAiApi.completion(completionRequest), responseHandler);
    }

    /** Convenience: completes {@code prompt} asynchronously and delivers only the response text. */
    @Experimental
    public void getCompletionAsync(String prompt, final ResponseHandler<String> responseHandler) {
        getCompletionsAsync(CompletionRequest.builder().prompt(prompt).build(), new ResponseHandler<CompletionResponse>() {
            @Override
            public void onResponse(CompletionResponse completionResponse) {
                responseHandler.onResponse(completionResponse.text());
            }

            @Override
            public void onFailure(Throwable t) {
                responseHandler.onFailure(t);
            }
        });
    }

    /**
     * Streams a completion over server-sent events; {@code stream} is forced to {@code true}.
     *
     * @throws IllegalArgumentException if the request explicitly has {@code stream == false}
     */
    @Experimental
    public void streamCompletions(CompletionRequest completionRequest, StreamingResponseHandler streamingResponseHandler) {
        if (Boolean.FALSE.equals(completionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'false' is not compatible with streamCompletion(...) method. If you do not need streaming, use one of getCompletion(...) or getCompletionAsync(...) methods. If you need streaming, do not set 'stream' parameter in CompletionRequest, or set it to 'true'.");
        }
        stream(CompletionRequest.builder().from(completionRequest).stream(true).build(),
                "v1/completions", CompletionResponse.class, CompletionResponse::text, streamingResponseHandler);
    }

    /** Convenience: streams a completion for {@code prompt}. */
    @Experimental
    public void streamCompletion(String prompt, StreamingResponseHandler streamingResponseHandler) {
        streamCompletions(CompletionRequest.builder().prompt(prompt).build(), streamingResponseHandler);
    }

    /**
     * Executes a (non-streaming) chat completion request synchronously.
     *
     * @throws IllegalArgumentException if the request has {@code stream == true}
     * @throws RuntimeException on transport failure or a non-2xx response
     */
    public ChatCompletionResponse getChatCompletions(ChatCompletionRequest chatCompletionRequest) {
        if (Boolean.TRUE.equals(chatCompletionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'true' is not compatible with getChatCompletion(...) method. If you need streaming, use one of streamChatCompletion(...) methods. If you do NOT need streaming, do not set 'stream' parameter in ChatCompletionRequest, or set it to 'false'.");
        }
        return (ChatCompletionResponse) execute(this.openAiApi.chatCompletion(chatCompletionRequest));
    }

    /** Convenience: sends {@code userMessage} as a single user message and returns the reply content. */
    @Experimental
    public String getChatCompletion(String userMessage) {
        return getChatCompletions(ChatCompletionRequest.builder().addUserMessage(userMessage).build()).content();
    }

    /**
     * Executes a (non-streaming) chat completion request asynchronously. Non-2xx responses are
     * reported via {@link ResponseHandler#onFailure(Throwable)}.
     *
     * @throws IllegalArgumentException if the request has {@code stream == true}
     */
    @Experimental
    public void getChatCompletionsAsync(ChatCompletionRequest chatCompletionRequest, final ResponseHandler<ChatCompletionResponse> responseHandler) {
        if (Boolean.TRUE.equals(chatCompletionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'true' is not compatible with getChatCompletionAsync(...) method. If you need streaming, use one of streamChatCompletion(...) methods. If you do NOT need streaming, do not set 'stream' parameter in ChatCompletionRequest, or set it to 'false'.");
        }
        enqueue(this.openAiApi.chatCompletion(chatCompletionRequest), responseHandler);
    }

    /** Convenience: sends {@code userMessage} asynchronously and delivers only the reply content. */
    @Experimental
    public void getChatCompletionAsync(String userMessage, final ResponseHandler<String> responseHandler) {
        getChatCompletionsAsync(ChatCompletionRequest.builder().addUserMessage(userMessage).build(), new ResponseHandler<ChatCompletionResponse>() {
            @Override
            public void onResponse(ChatCompletionResponse chatCompletionResponse) {
                responseHandler.onResponse(chatCompletionResponse.content());
            }

            @Override
            public void onFailure(Throwable t) {
                responseHandler.onFailure(t);
            }
        });
    }

    /**
     * Streams a chat completion over server-sent events; {@code stream} is forced to {@code true}.
     *
     * @throws IllegalArgumentException if the request explicitly has {@code stream == false}
     */
    @Experimental
    public void streamChatCompletions(ChatCompletionRequest chatCompletionRequest, StreamingResponseHandler streamingResponseHandler) {
        if (Boolean.FALSE.equals(chatCompletionRequest.stream())) {
            throw new IllegalArgumentException("Request parameter 'stream' with value 'false' is not compatible with streamChatCompletion(...) method. If you do not need streaming, use one of getChatCompletion(...) or getChatCompletionAsync(...) methods. If you need streaming, do not set 'stream' parameter in ChatCompletionRequest, or set it to 'true'.");
        }
        stream(ChatCompletionRequest.builder().from(chatCompletionRequest).stream(true).build(),
                "v1/chat/completions", ChatCompletionResponse.class,
                OpenAiService::firstChoiceDeltaContent, streamingResponseHandler);
    }

    /** Convenience: streams a chat completion for a single user message. */
    @Experimental
    public void streamChatCompletion(String userMessage, StreamingResponseHandler streamingResponseHandler) {
        stream(ChatCompletionRequest.builder().addUserMessage(userMessage).stream(true).build(),
                "v1/chat/completions", ChatCompletionResponse.class,
                OpenAiService::firstChoiceDeltaContent, streamingResponseHandler);
    }

    /**
     * Executes an embedding request synchronously.
     *
     * @throws RuntimeException on transport failure or a non-2xx response
     */
    public EmbeddingResponse getEmbeddings(EmbeddingRequest embeddingRequest) {
        return (EmbeddingResponse) execute(this.openAiApi.embedding(embeddingRequest));
    }

    /** Convenience: embeds {@code input} synchronously and returns the embedding vector. */
    @Experimental
    public List<Float> getEmbedding(String input) {
        return getEmbeddings(EmbeddingRequest.builder().input(input).build()).embedding();
    }

    /**
     * Executes an embedding request asynchronously. Non-2xx responses are reported via
     * {@link ResponseHandler#onFailure(Throwable)} rather than delivering a {@code null} body.
     */
    @Experimental
    public void getEmbeddingsAsync(EmbeddingRequest embeddingRequest, final ResponseHandler<EmbeddingResponse> responseHandler) {
        enqueue(this.openAiApi.embedding(embeddingRequest), responseHandler);
    }

    /** Convenience: embeds {@code input} asynchronously and delivers only the embedding vector. */
    @Experimental
    public void getEmbeddingAsync(String input, final ResponseHandler<List<Float>> responseHandler) {
        getEmbeddingsAsync(EmbeddingRequest.builder().input(input).build(), new ResponseHandler<EmbeddingResponse>() {
            @Override
            public void onResponse(EmbeddingResponse embeddingResponse) {
                responseHandler.onResponse(embeddingResponse.embedding());
            }

            @Override
            public void onFailure(Throwable t) {
                responseHandler.onFailure(t);
            }
        });
    }

    /**
     * Runs {@code call} synchronously and returns its body.
     *
     * @throws RuntimeException whose message is the error body on a non-2xx response, or wrapping
     *     the {@link IOException} on transport failure
     */
    private static <Res> Res execute(Call<Res> call) {
        try {
            Response<Res> response = call.execute();
            if (response.isSuccessful()) {
                return response.body();
            }
            throw new RuntimeException(errorMessage(response));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Enqueues {@code call} and routes the outcome to {@code responseHandler}: a 2xx body goes to
     * {@code onResponse}; a non-2xx status, a transport failure, or a failure to read the error body
     * goes to {@code onFailure}.
     */
    private static <Res> void enqueue(Call<Res> call, final ResponseHandler<Res> responseHandler) {
        call.enqueue(new Callback<Res>() {
            @Override
            public void onResponse(Call<Res> c, Response<Res> response) {
                if (response.isSuccessful()) {
                    responseHandler.onResponse(response.body());
                    return;
                }
                try {
                    responseHandler.onFailure(new RuntimeException(errorMessage(response)));
                } catch (IOException e) {
                    responseHandler.onFailure(e);
                }
            }

            @Override
            public void onFailure(Call<Res> c, Throwable t) {
                responseHandler.onFailure(t);
            }
        });
    }

    /** Reads the error body of a non-2xx Retrofit response, falling back to the status code. */
    private static String errorMessage(Response<?> response) throws IOException {
        // Retrofit documents errorBody() as non-null for unsuccessful responses; guard anyway.
        return response.errorBody() == null ? "HTTP " + response.code() : response.errorBody().string();
    }

    /**
     * Extracts the incremental text of the first choice of a streamed chat chunk. Returns
     * {@code null} when the chunk has no choices or no delta, which makes the stream listener skip
     * the chunk instead of reporting a failure.
     */
    private static String firstChoiceDeltaContent(ChatCompletionResponse chunk) {
        if (chunk.choices() == null || chunk.choices().isEmpty() || chunk.choices().get(0).delta() == null) {
            return null;
        }
        return chunk.choices().get(0).delta().content();
    }

    /**
     * POSTs {@code request} as JSON to {@code url + path} and consumes the server-sent-event stream.
     *
     * <p>Each event's {@code data} is parsed as {@code responseClass}; non-null text returned by
     * {@code textExtractor} is forwarded to {@code onPartialResponse} and accumulated. The
     * {@code [DONE]} sentinel triggers {@code onCompleteResponse} with the accumulated text. Parse
     * or extraction errors, transport errors and non-2xx responses go to {@code onFailure}.
     */
    private <Req, Res> void stream(Req request, String path, final Class<Res> responseClass,
                                   final Function<Res, String> textExtractor,
                                   final StreamingResponseHandler handler) {
        Request httpRequest = new Request.Builder()
                .url(this.url + path)
                .post(RequestBody.create(Json.toJson(request), JSON_MEDIA_TYPE))
                .build();
        EventSources.createFactory(this.okHttpClient).newEventSource(httpRequest, new EventSourceListener() {
            private final StringBuilder completeResponseBuilder = new StringBuilder();

            @Override
            public void onOpen(EventSource eventSource, okhttp3.Response response) {
                log.trace("onOpen() {}", response);
            }

            @Override
            public void onEvent(EventSource eventSource, String id, String type, String data) {
                log.trace("onEvent() data: {}", data);
                if ("[DONE]".equals(data)) {
                    handler.onCompleteResponse(completeResponseBuilder.toString());
                    return;
                }
                try {
                    String partial = textExtractor.apply(Json.fromJson(data, responseClass));
                    if (partial != null) {
                        completeResponseBuilder.append(partial);
                        handler.onPartialResponse(partial);
                    }
                } catch (Exception e) {
                    handler.onFailure(e);
                }
            }

            @Override
            public void onClosed(EventSource eventSource) {
                log.trace("onClosed()");
            }

            @Override
            public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response response) {
                log.trace("onFailure()\nThrowable: {}\nResponse: {}", t, response);
                if (t != null) {
                    handler.onFailure(t);
                    return;
                }
                if (response == null) {
                    // OkHttp may report a failure with neither a throwable nor a response.
                    handler.onFailure(new RuntimeException("Streaming request failed without a response"));
                    return;
                }
                try {
                    handler.onFailure(new RuntimeException(
                            response.body() == null ? "HTTP " + response.code() : response.body().string()));
                } catch (IOException e) {
                    handler.onFailure(e);
                }
            }
        });
    }
}
