package org.jkiss.dbeaver.model.ai.impl;

import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.concurrent.Flow;
import org.jkiss.code.NotNull;
import org.jkiss.code.Nullable;
import org.jkiss.dbeaver.DBException;
import org.jkiss.dbeaver.Log;
import org.jkiss.dbeaver.model.ai.AIAssistant;
import org.jkiss.dbeaver.model.ai.AICommandRequest;
import org.jkiss.dbeaver.model.ai.AICommandResult;
import org.jkiss.dbeaver.model.ai.AIConstants;
import org.jkiss.dbeaver.model.ai.AIMessage;
import org.jkiss.dbeaver.model.ai.AIMessageType;
import org.jkiss.dbeaver.model.ai.AITextUtils;
import org.jkiss.dbeaver.model.ai.AITranslateRequest;
import org.jkiss.dbeaver.model.ai.engine.AIDatabaseContext;
import org.jkiss.dbeaver.model.ai.engine.AIEngine;
import org.jkiss.dbeaver.model.ai.engine.AIEngineRequest;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponse;
import org.jkiss.dbeaver.model.ai.engine.AIEngineResponseChunk;
import org.jkiss.dbeaver.model.ai.engine.TooManyRequestsException;
import org.jkiss.dbeaver.model.ai.impl.MessageChunk;
import org.jkiss.dbeaver.model.ai.prompt.AIPromptBuilder;
import org.jkiss.dbeaver.model.ai.prompt.AIPromptFormatter;
import org.jkiss.dbeaver.model.ai.registry.AIEngineRegistry;
import org.jkiss.dbeaver.model.ai.registry.AIFormatterRegistry;
import org.jkiss.dbeaver.model.ai.registry.AISettingsRegistry;
import org.jkiss.dbeaver.model.ai.utils.AIUtils;
import org.jkiss.dbeaver.model.ai.utils.DatabaseMetadataUtils;
import org.jkiss.dbeaver.model.ai.utils.ThrowableSupplier;
import org.jkiss.dbeaver.model.app.DBPWorkspace;
import org.jkiss.dbeaver.model.exec.DBExecUtils;
import org.jkiss.dbeaver.model.runtime.DBRProgressMonitor;
import org.jkiss.dbeaver.model.sql.SQLUtils;
import org.jkiss.dbeaver.utils.RuntimeUtils;

/* loaded from: input_file:org/jkiss/dbeaver/model/ai/impl/AIAssistantImpl.class */
/**
 * Default {@link AIAssistant} implementation.
 * <p>
 * Builds a system prompt (enriched with a database metadata snapshot when a database context is
 * available), sends it together with the user's message to an {@link AIEngine}, and post-processes
 * the first completion variant into SQL code and explanatory text chunks.
 */
public class AIAssistantImpl implements AIAssistant {
    private static final Log log = Log.getLog(AIAssistantImpl.class);

    /** Maximum number of attempts made while the engine keeps reporting {@link TooManyRequestsException}. */
    private static final int MANY_REQUESTS_RETRIES = 3;
    /** Pause between attempts after a {@link TooManyRequestsException}, in milliseconds. */
    private static final int MANY_REQUESTS_TIMEOUT = 500;

    private final AISettingsRegistry settingsRegistry = AISettingsRegistry.getInstance();
    private final AIEngineRegistry engineRegistry = AIEngineRegistry.getInstance();
    private final AIFormatterRegistry formatterRegistry = AIFormatterRegistry.getInstance();

    /**
     * Does nothing: this implementation performs no per-workspace initialization.
     *
     * @param workspace the workspace being initialized (unused)
     */
    @Override
    public void initialize(@NotNull DBPWorkspace workspace) {
        // Intentionally empty.
    }

    /**
     * Translates a natural-language request into SQL.
     *
     * @param monitor progress monitor used for prompt building and the engine call
     * @param request the text to translate, its database context and an optional explicit engine
     * @return the SQL extracted from the engine's first completion variant
     * @throws DBException if prompt building or the engine call fails, or the engine returned no variants
     */
    @Override
    @NotNull
    public String translateTextToSql(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AITranslateRequest request
    ) throws DBException {
        AIEngine engine = resolveEngine(request.engine());
        AIMessage userMessage = new AIMessage(AIMessageType.USER, request.text());
        AIEngineResponse response = requestCompletion(
            engine,
            monitor,
            createSqlRequest(monitor, engine, request.context(), userMessage)
        );
        MessageChunk[] chunks = processAndSplitCompletion(monitor, request.context(), firstVariant(response));
        return AITextUtils.convertToSQL(
            userMessage,
            chunks,
            request.context().getExecutionContext().getDataSource()
        );
    }

    /**
     * Executes a natural-language command and splits the answer into SQL and explanatory text.
     * <p>
     * If the completion contains several code chunks, the last one is returned as the SQL. All text
     * chunks are concatenated in order, without separators.
     *
     * @param monitor progress monitor used for prompt building and the engine call
     * @param request the command text, its database context and an optional explicit engine
     * @return the extracted SQL (may be {@code null} when no code chunk was found) and the text
     * @throws DBException if prompt building or the engine call fails, or the engine returned no variants
     */
    @Override
    @NotNull
    public AICommandResult command(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AICommandRequest request
    ) throws DBException {
        AIEngine engine = resolveEngine(request.engine());
        AIEngineResponse response = requestCompletion(
            engine,
            monitor,
            createSqlRequest(monitor, engine, request.context(), AIMessage.userMessage(request.text()))
        );
        MessageChunk[] chunks = processAndSplitCompletion(monitor, request.context(), firstVariant(response));

        String sql = null;
        StringBuilder explanation = new StringBuilder();
        for (MessageChunk chunk : chunks) {
            if (chunk instanceof MessageChunk.Code) {
                // A later code chunk replaces an earlier one: the last code block wins.
                sql = ((MessageChunk.Code) chunk).text();
            } else if (chunk instanceof MessageChunk.Text) {
                explanation.append(((MessageChunk.Text) chunk).text());
            }
        }
        return new AICommandResult(sql, explanation.toString());
    }

    /**
     * Checks the configuration of the currently active engine.
     *
     * @return the active engine's own {@code hasValidConfiguration()} result
     * @throws DBException if the active engine cannot be obtained
     */
    @Override
    public boolean hasValidConfiguration() throws DBException {
        return getActiveEngine().hasValidConfiguration();
    }

    /**
     * Post-processes a raw completion and splits it into text and code chunks.
     * <p>
     * The splitting uses the SQL dialect of the context's data source.
     *
     * @param monitor    progress monitor passed to completion processing
     * @param context    database context providing the execution context and scope object
     * @param completion raw completion text returned by the engine
     * @return the processed completion split into chunks
     * @throws DBException if completion processing or formatter lookup fails
     */
    protected MessageChunk[] processAndSplitCompletion(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIDatabaseContext context,
        @NotNull String completion
    ) throws DBException {
        return AITextUtils.splitIntoChunks(
            SQLUtils.getDialectFromDataSource(context.getExecutionContext().getDataSource()),
            AIUtils.processCompletion(
                monitor,
                context.getExecutionContext(),
                context.getScopeObject(),
                completion,
                formatter(),
                true
            )
        );
    }

    /**
     * Invokes {@code supplier}, retrying while it throws {@link TooManyRequestsException}.
     * <p>
     * Makes at most {@link #MANY_REQUESTS_RETRIES} attempts and pauses {@link #MANY_REQUESTS_TIMEOUT}
     * milliseconds between them. Any other {@link DBException} propagates immediately.
     *
     * @param supplier the call to perform
     * @return the first successful result
     * @throws DBException if every attempt was rejected; the last rejection is attached as the cause
     */
    private static <T> T callWithRetry(@NotNull ThrowableSupplier<T, DBException> supplier) throws DBException {
        TooManyRequestsException lastRejection = null;
        for (int attempt = 1; attempt <= MANY_REQUESTS_RETRIES; attempt++) {
            try {
                return supplier.get();
            } catch (TooManyRequestsException e) {
                lastRejection = e;
                // No pause after the final attempt: we are about to give up anyway.
                if (attempt < MANY_REQUESTS_RETRIES) {
                    log.debug("Too many engine requests. Retry after " + MANY_REQUESTS_TIMEOUT + "ms");
                    RuntimeUtils.pause(MANY_REQUESTS_TIMEOUT);
                }
            }
        }
        throw new DBException("Request failed after " + MANY_REQUESTS_RETRIES + " attempts", lastRejection);
    }

    /**
     * Returns the completion engine selected as active in the AI settings.
     *
     * @throws DBException if the engine cannot be obtained from the registry
     */
    protected AIEngine getActiveEngine() throws DBException {
        return this.engineRegistry.getCompletionEngine(this.settingsRegistry.getSettings().activeEngine());
    }

    /**
     * Requests a (non-streaming) completion from {@code engine}.
     * <p>
     * The call is retried on {@link TooManyRequestsException}. When the engine has logging enabled,
     * the request and response are logged at debug level.
     *
     * @throws DBException on engine failure; non-{@link DBException} failures are wrapped
     */
    protected AIEngineResponse requestCompletion(
        @NotNull AIEngine engine,
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngineRequest request
    ) throws DBException {
        try {
            if (engine.isLoggingEnabled()) {
                log.debug("Requesting completion [request=" + request + "]");
            }
            AIEngineResponse response = callWithRetry(() -> engine.requestCompletion(monitor, request));
            if (engine.isLoggingEnabled()) {
                log.debug("Received completion [response=" + response + "]");
            }
            return response;
        } catch (DBException e) {
            throw e;
        } catch (Exception e) {
            // Wrap unexpected (runtime) failures so callers only have to handle DBException.
            throw new DBException("Error requesting completion", e);
        }
    }

    /**
     * Requests a streaming completion from {@code engine}.
     * <p>
     * The initial request is retried on {@link TooManyRequestsException}. When the engine has logging
     * enabled, each subscription is logged and the subscriber is wrapped in a {@code LogSubscriber}.
     *
     * @return a publisher of response chunks
     * @throws DBException on engine failure; non-{@link DBException} failures are wrapped
     */
    protected Flow.Publisher<AIEngineResponseChunk> requestCompletionStream(
        @NotNull AIEngine engine,
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngineRequest request
    ) throws DBException {
        try {
            Flow.Publisher<AIEngineResponseChunk> publisher =
                callWithRetry(() -> engine.requestCompletionStream(monitor, request));
            // Capture the flag now so every later subscription behaves consistently.
            boolean loggingEnabled = engine.isLoggingEnabled();
            return subscriber -> {
                if (loggingEnabled) {
                    log.debug("Requesting completion stream [request=" + request + "]");
                    publisher.subscribe(new LogSubscriber(log, subscriber));
                } else {
                    publisher.subscribe(subscriber);
                }
            };
        } catch (DBException e) {
            log.error("Error requesting completion stream", e);
            throw e;
        } catch (Exception e) {
            log.error("Error requesting completion stream", e);
            throw new DBException("Error requesting completion stream", e);
        }
    }

    /**
     * Returns the core prompt formatter ({@link AIConstants#CORE_FORMATTER}) from the formatter registry.
     *
     * @throws DBException if the formatter cannot be obtained
     */
    protected AIPromptFormatter formatter() throws DBException {
        return this.formatterRegistry.getFormatter(AIConstants.CORE_FORMATTER);
    }

    /**
     * Creates a prompt builder using the default {@link #formatter()}.
     *
     * @see #buildPrompt(DBRProgressMonitor, AIEngine, AIPromptFormatter, AIDatabaseContext)
     */
    protected AIPromptBuilder buildPrompt(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @Nullable AIDatabaseContext context
    ) throws DBException {
        return buildPrompt(monitor, engine, formatter(), context);
    }

    /**
     * Creates a prompt builder for the context's data source and adds a metadata description to it.
     *
     * @param monitor   progress monitor for metadata reads
     * @param engine    engine whose token limits bound the metadata snapshot
     * @param formatter formatter used for the prompt and the metadata snapshot
     * @param context   database context; {@code null} means no data source and no snapshot
     * @return the prepared prompt builder
     * @throws DBException if describing the database metadata fails
     */
    protected AIPromptBuilder buildPrompt(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @NotNull AIPromptFormatter formatter,
        @Nullable AIDatabaseContext context
    ) throws DBException {
        AIPromptBuilder promptBuilder = AIPromptBuilder.createForDataSource(
            context != null ? context.getDataSource() : null,
            formatter
        );
        if (context != null) {
            DBExecUtils.tryExecuteRecover(
                monitor,
                context.getExecutionContext().getDataSource(),
                recoverMonitor -> {
                    try {
                        // Use the monitor handed to the runnable, not the outer one.
                        describeDatabaseMetadata(recoverMonitor, engine, formatter, context, promptBuilder);
                    } catch (DBException e) {
                        // The runnable cannot throw DBException directly, so tunnel it.
                        throw new InvocationTargetException(e);
                    }
                }
            );
        } else {
            describeDatabaseMetadata(monitor, engine, formatter, context, promptBuilder);
        }
        return promptBuilder;
    }

    /**
     * Adds a database metadata snapshot to {@code promptBuilder}, using the default {@link #formatter()}.
     *
     * @see #describeDatabaseMetadata(DBRProgressMonitor, AIEngine, AIPromptFormatter, AIDatabaseContext, AIPromptBuilder)
     */
    protected void describeDatabaseMetadata(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @Nullable AIDatabaseContext context,
        @NotNull AIPromptBuilder promptBuilder
    ) throws DBException {
        describeDatabaseMetadata(monitor, engine, formatter(), context, promptBuilder);
    }

    /**
     * Adds a database metadata snapshot for {@code context} to {@code promptBuilder}.
     * <p>
     * Does nothing when {@code context} is {@code null}. The snapshot is bounded by the engine's
     * maximum request token count.
     *
     * @throws DBException if the metadata cannot be described
     */
    protected void describeDatabaseMetadata(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @NotNull AIPromptFormatter formatter,
        @Nullable AIDatabaseContext context,
        @NotNull AIPromptBuilder promptBuilder
    ) throws DBException {
        if (context != null) {
            promptBuilder.addDatabaseSnapshot(
                DatabaseMetadataUtils.describeContext(
                    monitor,
                    context,
                    formatter,
                    AIUtils.getMaxRequestTokens(engine, monitor)
                )
            );
        }
    }

    /**
     * Returns {@code requested} if non-null, otherwise the active engine from settings.
     */
    private AIEngine resolveEngine(@Nullable AIEngine requested) throws DBException {
        return requested != null ? requested : getActiveEngine();
    }

    /**
     * Builds the engine request shared by {@link #translateTextToSql} and {@link #command}.
     * <p>
     * The request contains a text-to-SQL system prompt followed by the user message, truncated to the
     * engine's maximum context size.
     */
    private AIEngineRequest createSqlRequest(
        @NotNull DBRProgressMonitor monitor,
        @NotNull AIEngine engine,
        @Nullable AIDatabaseContext context,
        @NotNull AIMessage userMessage
    ) throws DBException {
        AIMessage systemMessage = AIMessage.systemMessage(
            buildPrompt(monitor, engine, context)
                .addGoals("Translate natural language text to SQL.")
                .addOutputFormats(
                    "Place any explanation or comments before the SQL code block.",
                    "Provide the SQL query in a fenced Markdown code block."
                )
                .build()
        );
        return new AIEngineRequest(
            AIUtils.truncateMessages(
                true,
                List.of(systemMessage, userMessage),
                engine.getMaxContextSize(monitor)
            )
        );
    }

    /**
     * Returns the first completion variant of {@code response}.
     *
     * @throws DBException if the response carries no variants (instead of an unchecked NPE/IOOBE)
     */
    private static String firstVariant(@NotNull AIEngineResponse response) throws DBException {
        if (response.variants() == null || response.variants().isEmpty()) {
            throw new DBException("AI engine returned no completion variants");
        }
        return response.variants().get(0);
    }
}
