package org.jkiss.dbeaver.model.ai.engine.openai;

import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import java.util.List;
import java.util.concurrent.Flow;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.ai.AIConstants;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.engine.AIEngineRequest;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponse;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseChunk;
import org.jkiss.dbeaver.model.ai.engine.BaseCompletionEngine;
import org.jkiss.dbeaver.model.ai.engine.LegacyAISettings;
import org.jkiss.dbeaver.model.ai.engine.openai.OpenAIBaseProperties;
import org.jkiss.dbeaver.model.ai.registry.AISettingsRegistry;
import org.jkiss.dbeaver.model.ai.utils.DisposableLazyValue;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;

/* loaded from: input_file:org/jkiss/dbeaver/model/ai/engine/openai/OpenAICompletionEngine.class */
/**
 * Completion engine backed by the OpenAI chat-completion API.
 * <p>
 * The {@link OpenAIClient} is created lazily via {@link #createClient()} the first time it is
 * needed and is disposed whenever AI settings are updated (see
 * {@link #onSettingsUpdate(AISettingsRegistry)}); presumably the lazy value re-initializes it on the
 * next access — confirm against {@code DisposableLazyValue}.
 *
 * @param <PROPS> type of the engine properties (model, token, temperature)
 */
public class OpenAICompletionEngine<PROPS extends OpenAIBaseProperties> extends BaseCompletionEngine<PROPS> {
    private static final Log log = Log.getLog(OpenAICompletionEngine.class);

    /** Base URL of the public OpenAI REST API, used by {@link #createClient()}. */
    public static final String OPENAI_ENDPOINT = "https://api.openai.com/v1/";

    /** Lazily created API client; disposed on settings change. */
    private final DisposableLazyValue<OpenAIClient, DBException> openAiService;

    public OpenAICompletionEngine(AISettingsRegistry registry) {
        super(registry);
        this.openAiService = new DisposableLazyValue<OpenAIClient, DBException>() {
            @NotNull
            @Override
            public OpenAIClient initialize() throws DBException {
                // Delegates to the (overridable) factory so subclasses can point at other endpoints.
                return OpenAICompletionEngine.this.createClient();
            }

            @Override
            public void onDispose(OpenAIClient client) {
                client.close();
            }
        };
    }

    /**
     * Returns the maximum token count of the currently configured model.
     *
     * @throws DBException if the engine properties cannot be read
     */
    @Override
    public int getMaxContextSize(@NotNull DBRProgressMonitor monitor) throws DBException {
        return OpenAIModel.getByName(getProperties().getModel()).getMaxTokens();
    }

    /**
     * Performs a blocking (non-streaming) completion request.
     *
     * @return a response holding the message content of every returned choice, in order
     * @throws DBException if the client cannot be created or the request fails
     */
    @NotNull
    @Override
    public AIEngineResponse requestCompletion(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngineRequest request
    ) throws DBException {
        return new AIEngineResponse(
            complete(monitor, request.messages()).getChoices().stream()
                .map(choice -> choice.getMessage().getContent())
                .toList()
        );
    }

    /**
     * Performs a streaming completion request.
     * <p>
     * The upstream chunk stream is requested eagerly; each subscriber to the returned publisher is
     * attached to that upstream through an adapter that converts every {@link ChatCompletionChunk}
     * into an {@link AIEngineResponseChunk}.
     *
     * @throws DBException if the client cannot be created or the request cannot be built
     */
    @NotNull
    @Override
    public Flow.Publisher<AIEngineResponseChunk> requestCompletionStream(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngineRequest request
    ) throws DBException {
        OpenAIClient client = this.openAiService.getInstance();
        Flow.Publisher<ChatCompletionChunk> upstream =
            client.createChatCompletionStream(monitor, buildRequest(request.messages(), true));
        return subscriber -> upstream.subscribe(new ChunkAdapter(subscriber));
    }

    /**
     * Disposes the cached client so that changed settings (token, endpoint) take effect.
     * Disposal failures are logged, not propagated.
     */
    @Override
    public void onSettingsUpdate(@NotNull AISettingsRegistry registry) {
        try {
            this.openAiService.dispose();
        } catch (DBException e) {
            log.error("Error disposing OpenAI service", e);
        }
    }

    /**
     * Sends a non-streaming chat-completion request for the given conversation.
     *
     * @param monitor  progress monitor passed through to the client
     * @param messages conversation to send
     * @return the raw completion result
     * @throws DBException if the client cannot be created or the request fails
     */
    @NotNull
    protected ChatCompletionResult complete(
        @NotNull DBRProgressMonitor monitor,
        @NotNull List<AIMessage> messages
    ) throws DBException {
        OpenAIClient client = this.openAiService.getInstance();
        return client.createChatCompletion(monitor, buildRequest(messages, false));
    }

    /**
     * Builds the chat-completion request shared by the streaming and non-streaming paths.
     *
     * @param stream when {@code true} the request asks for a streamed response; otherwise the
     *               {@code stream} field is left unset, matching a plain request
     */
    @NotNull
    private ChatCompletionRequest buildRequest(
        @NotNull List<AIMessage> messages,
        boolean stream
    ) throws DBException {
        return ChatCompletionRequest.builder()
            .messages(fromMessages(messages))
            .temperature(temperature())
            .frequencyPenalty(0.0d)
            .presencePenalty(0.0d)
            .maxTokens(AIConstants.MAX_RESPONSE_TOKENS)
            .n(1)
            .model(model())
            // null (not FALSE) keeps the field absent for non-streaming requests
            .stream(stream ? Boolean.TRUE : null)
            .build();
    }

    /** Converts engine messages into OpenAI chat messages, preserving order. */
    @NotNull
    private static List<ChatMessage> fromMessages(@NotNull List<AIMessage> messages) {
        return messages.stream()
            .map(message -> new ChatMessage(mapRole(message.getRole()), message.getContent()))
            .toList();
    }

    /**
     * Maps an engine message type to the OpenAI role string.
     *
     * @return {@code "system"}, {@code "user"} or {@code "assistant"}; {@code null} for any other
     *         type (e.g. {@code ERROR}), which has no OpenAI counterpart
     */
    private static String mapRole(AIMessageType type) {
        return switch (type) {
            case SYSTEM -> "system";
            case USER -> "user";
            case ASSISTANT -> "assistant";
            default -> null;
        };
    }

    /**
     * Creates the API client for {@link #OPENAI_ENDPOINT}, authenticated with the configured token.
     * Subclasses may override to target a different endpoint.
     */
    protected OpenAIClient createClient() throws DBException {
        return new OpenAIClient(OPENAI_ENDPOINT, List.of(new OpenAIRequestFilter(getProperties().getToken())));
    }

    /** Returns the canonical name of the configured model. */
    protected String model() throws DBException {
        return OpenAIModel.getByName(getProperties().getModel()).getName();
    }

    /** Returns the configured sampling temperature. */
    protected double temperature() throws DBException {
        return getProperties().getTemperature();
    }

    /**
     * Returns the OpenAI engine properties from the registry's current settings.
     *
     * @throws DBException if the settings cannot be loaded
     */
    @SuppressWarnings("unchecked") // the OpenAI engine configuration always carries PROPS
    @Override
    public PROPS getProperties() throws DBException {
        return (PROPS) ((LegacyAISettings) this.registry.getSettings()
            .getEngineConfiguration(OpenAIConstants.OPENAI_ENGINE)).getProperties();
    }

    /**
     * Adapts the OpenAI chunk stream to the engine's chunk type. Subscription, error and
     * completion signals are forwarded unchanged.
     */
    private static final class ChunkAdapter implements Flow.Subscriber<ChatCompletionChunk> {
        private final Flow.Subscriber<? super AIEngineResponseChunk> downstream;

        ChunkAdapter(Flow.Subscriber<? super AIEngineResponseChunk> downstream) {
            this.downstream = downstream;
        }

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            downstream.onSubscribe(subscription);
        }

        @Override
        public void onNext(ChatCompletionChunk chunk) {
            downstream.onNext(new AIEngineResponseChunk(extractContent(chunk)));
        }

        @Override
        public void onError(Throwable error) {
            downstream.onError(error);
        }

        @Override
        public void onComplete() {
            downstream.onComplete();
        }

        /**
         * Collects the non-null content of every choice in the chunk. Choices without a message
         * or content (e.g. a final delta) are skipped individually rather than truncating the rest.
         */
        @NotNull
        private static List<String> extractContent(@NotNull ChatCompletionChunk chunk) {
            return chunk.getChoices().stream()
                .map(choice -> choice.getMessage())
                .filter(message -> message != null)
                .map(ChatMessage::getContent)
                .filter(content -> content != null)
                .toList();
        }
    }
}
