package org.jkiss.dbeaver.model.ai;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Flow;
import org.jkiss.code.NotNull;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.ai.MessageChunk;
import org.jkiss.dbeaver.model.ai.completion.DAIChatMessage;
import org.jkiss.dbeaver.model.ai.completion.DAIChatRequest;
import org.jkiss.dbeaver.model.ai.completion.DAIChatRole;
import org.jkiss.dbeaver.model.ai.completion.DAICommandRequest;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionChunk;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionContext;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionEngine;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionRequest;
import org.jkiss.dbeaver.model.ai.completion.DAICompletionResponse;
import org.jkiss.dbeaver.model.ai.completion.DAITranslateRequest;
import org.jkiss.dbeaver.model.ai.format.IAIFormatter;
import org.jkiss.dbeaver.model.ai.metadata.MetadataProcessor;
import org.jkiss.dbeaver.model.ai.utils.AIUtils;
import org.jkiss.dbeaver.model.ai.utils.ThrowableSupplier;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.sql.SQLUtils;

/* loaded from: input_file:org/jkiss/dbeaver/model/ai/AIAssistantImpl.class */
/**
 * Default {@link AIAssistant} implementation.
 *
 * <p>Builds prompts from a completion context and sends them to either the engine named in
 * the request or the currently active engine. Rate-limited calls are retried a bounded
 * number of times, and the textual replies are post-processed into SQL/text chunks.
 */
public class AIAssistantImpl implements AIAssistant {
    /** Maximum number of attempts made when an engine reports {@code TooManyRequestsException}. */
    private static final int MAX_RETRIES = 3;
    private final AISettingsRegistry settingsRegistry = AISettingsRegistry.getInstance();
    private final AIEngineRegistry engineRegistry = AIEngineRegistry.getInstance();
    private final AIFormatterRegistry formatterRegistry = AIFormatterRegistry.getInstance();
    private final AIAssistantRegistry assistantRegistry = AIAssistantRegistry.getInstance();
    private static final Log log = Log.getLog(AIAssistantImpl.class);
    private static final MetadataProcessor metadataProcessor = MetadataProcessor.INSTANCE;

    /**
     * Starts a streaming chat completion.
     *
     * <p>If the request carries a completion context, a system message (system prompt plus
     * context description) is prepended to the request's own messages. The combined list is
     * truncated to the engine's maximum context size before sending.
     *
     * @param monitor progress monitor passed to the engine and metadata processor
     * @param request chat request; its engine overrides the active engine when non-null
     * @return publisher of completion chunks
     * @throws DBException if the engine cannot be resolved or the request fails
     */
    @Override
    @NotNull
    public Flow.Publisher<DAICompletionChunk> chat(@NotNull DBRProgressMonitor monitor, @NotNull DAIChatRequest request) throws DBException {
        DAICompletionEngine engine = resolveEngine(request.engine());
        List<DAIChatMessage> messages = new ArrayList<>();
        if (request.context() != null) {
            messages.add(buildSystemMessage(monitor, engine, request.context()));
        }
        messages.addAll(request.messages());
        DAICompletionRequest completionRequest = new DAICompletionRequest(
            AIUtils.truncateMessages(true, messages, engine.getMaxContextSize(monitor)));
        return requestCompletionStream(engine, monitor, completionRequest);
    }

    /**
     * Translates natural-language text into SQL for the request's data source.
     *
     * @param monitor progress monitor
     * @param request translation request (text, context, optional engine)
     * @return SQL produced by {@code AITextUtils.convertToSQL}
     * @throws DBException if the completion fails or the engine returns no choices
     */
    @Override
    @NotNull
    public String translateTextToSql(@NotNull DBRProgressMonitor monitor, @NotNull DAITranslateRequest request) throws DBException {
        DAICompletionEngine engine = resolveEngine(request.engine());
        DAIChatMessage userMessage = new DAIChatMessage(DAIChatRole.USER, request.text());
        DAICompletionContext context = request.context();
        String completion = requestSingleCompletion(monitor, engine, context, userMessage);
        MessageChunk[] chunks = processAndSplitCompletion(monitor, context, completion);
        return AITextUtils.convertToSQL(userMessage, chunks, context.getExecutionContext().getDataSource());
    }

    /**
     * Executes a free-form command and splits the reply into code and explanatory text.
     *
     * <p>If the reply contains several code chunks, the last one is kept. All text chunks are
     * concatenated in order, without separators.
     *
     * @param monitor progress monitor
     * @param request command request (text, context, optional engine)
     * @return the extracted code (may be {@code null} if none) and the concatenated text
     * @throws DBException if the completion fails or the engine returns no choices
     */
    @Override
    @NotNull
    public CommandResult command(@NotNull DBRProgressMonitor monitor, @NotNull DAICommandRequest request) throws DBException {
        DAICompletionEngine engine = resolveEngine(request.engine());
        DAICompletionContext context = request.context();
        String completion = requestSingleCompletion(monitor, engine, context, DAIChatMessage.userMessage(request.text()));
        MessageChunk[] chunks = processAndSplitCompletion(monitor, context, completion);

        String code = null;
        StringBuilder text = new StringBuilder();
        for (MessageChunk chunk : chunks) {
            if (chunk instanceof MessageChunk.Code) {
                code = ((MessageChunk.Code) chunk).text();
            } else if (chunk instanceof MessageChunk.Text) {
                text.append(((MessageChunk.Text) chunk).text());
            }
        }
        return new CommandResult(code, text.toString());
    }

    /**
     * Reports whether the currently active engine is properly configured.
     *
     * @throws DBException if the active engine cannot be resolved
     */
    @Override
    public boolean hasValidConfiguration() throws DBException {
        return getActiveEngine().hasValidConfiguration();
    }

    /**
     * Post-processes a raw completion and splits it into code/text chunks.
     *
     * <p>The SQL dialect is taken from the context's data source.
     *
     * @param monitor progress monitor
     * @param context completion context providing the execution context and scope object
     * @param completion raw completion text returned by the engine
     * @return the completion split into message chunks
     * @throws DBException if processing fails
     */
    protected MessageChunk[] processAndSplitCompletion(@NotNull DBRProgressMonitor monitor, @NotNull DAICompletionContext context, @NotNull String completion) throws DBException {
        return AITextUtils.splitIntoChunks(
            SQLUtils.getDialectFromDataSource(context.getExecutionContext().getDataSource()),
            AIUtils.processCompletion(monitor, context.getExecutionContext(), context.getScopeObject(), completion, formatter(), true));
    }

    /**
     * Invokes {@code supplier}, retrying up to {@link #MAX_RETRIES} times while it signals
     * rate limiting via {@code TooManyRequestsException}.
     *
     * <p>Any other exception propagates immediately.
     *
     * <p>NOTE(review): attempts are made back-to-back with no delay. Consider adding back-off
     * if the engines do not already throttle themselves.
     *
     * @throws DBException when every attempt was rate-limited; the last rate-limit error is
     *         attached as the cause
     */
    private static <T> T callWithRetry(ThrowableSupplier<T, DBException> supplier) throws DBException {
        TooManyRequestsException lastError = null;
        for (int attempt = 0; attempt < MAX_RETRIES; attempt++) {
            try {
                return supplier.get();
            } catch (TooManyRequestsException e) {
                // Remember the failure so the final exception explains why we gave up.
                lastError = e;
            }
        }
        throw new DBException("Request failed after " + MAX_RETRIES + " attempts", lastError);
    }

    /**
     * Returns the completion engine selected in the AI settings.
     *
     * @throws DBException if the engine cannot be created
     */
    protected DAICompletionEngine getActiveEngine() throws DBException {
        return this.engineRegistry.getCompletionEngine(this.settingsRegistry.getSettings().getActiveEngine());
    }

    /**
     * Sends a non-streaming completion request, with rate-limit retries and optional debug
     * logging of the request and response.
     *
     * @param engine engine to call
     * @param monitor progress monitor
     * @param request completion request
     * @return the engine's response
     * @throws DBException on failure; non-{@code DBException} errors are wrapped
     */
    protected DAICompletionResponse requestCompletion(@NotNull DAICompletionEngine engine, @NotNull DBRProgressMonitor monitor, @NotNull DAICompletionRequest request) throws DBException {
        try {
            if (engine.isLoggingEnabled()) {
                log.debug("Requesting completion [request=" + request + "]");
            }
            DAICompletionResponse response = callWithRetry(() -> engine.requestCompletion(monitor, request));
            if (engine.isLoggingEnabled()) {
                log.debug("Received completion [response=" + response + "]");
            }
            return response;
        } catch (DBException e) {
            log.error("Error requesting completion", e);
            throw e;
        } catch (Exception e) {
            log.error("Error requesting completion", e);
            throw new DBException("Error requesting completion", e);
        }
    }

    /**
     * Opens a streaming completion.
     *
     * <p>When engine logging is enabled, each subscriber is wrapped in a {@code LogSubscriber}.
     * The request is logged at subscription time.
     *
     * @throws DBException on failure; non-{@code DBException} errors are wrapped
     */
    private Flow.Publisher<DAICompletionChunk> requestCompletionStream(@NotNull DAICompletionEngine engine, @NotNull DBRProgressMonitor monitor, @NotNull DAICompletionRequest request) throws DBException {
        try {
            Flow.Publisher<DAICompletionChunk> publisher = callWithRetry(() -> engine.requestCompletionStream(monitor, request));
            return subscriber -> {
                if (engine.isLoggingEnabled()) {
                    log.debug("Requesting completion stream [request=" + request + "]");
                    publisher.subscribe(new LogSubscriber(log, subscriber));
                } else {
                    publisher.subscribe(subscriber);
                }
            };
        } catch (DBException e) {
            log.error("Error requesting completion stream", e);
            throw e;
        } catch (Exception e) {
            log.error("Error requesting completion stream", e);
            throw new DBException("Error requesting completion stream", e);
        }
    }

    /** Returns the system prompt placed before the context description in every request. */
    protected String getSystemPrompt() {
        return "You are SQL assistant. You must produce SQL code for given prompt.\nYou must produce valid SQL statement enclosed with Markdown code block and terminated with semicolon.\nAll database object names should be properly escaped according to the SQL dialect.\nAll comments MUST be placed before query outside markdown code block.\nBe polite.\n";
    }

    /** Returns the core formatter used for context description and completion post-processing. */
    protected IAIFormatter formatter() throws DBException {
        return this.formatterRegistry.getFormatter(AIConstants.CORE_FORMATTER);
    }

    /** Returns the assistant registered in the assistant registry. */
    protected AIAssistant assistant() throws DBException {
        return this.assistantRegistry.getAssistant();
    }

    /** Returns {@code requested} when non-null, otherwise the active engine. */
    private DAICompletionEngine resolveEngine(DAICompletionEngine requested) throws DBException {
        return requested != null ? requested : getActiveEngine();
    }

    /**
     * Builds the system message: the system prompt, a line separator, and the metadata
     * description of {@code context}, limited to the engine's max request tokens.
     */
    @NotNull
    private DAIChatMessage buildSystemMessage(@NotNull DBRProgressMonitor monitor, @NotNull DAICompletionEngine engine, @NotNull DAICompletionContext context) throws DBException {
        return DAIChatMessage.systemMessage(getSystemPrompt() + System.lineSeparator()
            + metadataProcessor.describeContext(monitor, context, formatter(), AIUtils.getMaxRequestTokens(engine, monitor)));
    }

    /**
     * Sends a single-turn request (system message plus one user message) and returns the
     * text of the first choice.
     */
    @NotNull
    private String requestSingleCompletion(@NotNull DBRProgressMonitor monitor, @NotNull DAICompletionEngine engine, @NotNull DAICompletionContext context, @NotNull DAIChatMessage userMessage) throws DBException {
        List<DAIChatMessage> messages = List.of(buildSystemMessage(monitor, engine, context), userMessage);
        DAICompletionRequest request = new DAICompletionRequest(
            AIUtils.truncateMessages(true, messages, engine.getMaxContextSize(monitor)));
        return firstChoiceText(requestCompletion(engine, monitor, request));
    }

    /**
     * Returns the text of the first choice in {@code response}.
     *
     * @throws DBException if the response contains no choices (instead of an
     *         {@code IndexOutOfBoundsException})
     */
    @NotNull
    private static String firstChoiceText(@NotNull DAICompletionResponse response) throws DBException {
        if (response.choices().isEmpty()) {
            throw new DBException("AI engine returned no completion choices");
        }
        return response.choices().get(0).text();
    }
}
