/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.ml;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.RemoteModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.neuralsearch.ml.AsymmetricModelDetector;
import org.opensearch.neuralsearch.processor.EmbeddingContentType;
import org.opensearch.neuralsearch.processor.InferenceRequest;

/**
 * Static factory methods that assemble {@link MLInput} requests for the inference paths
 * handled in this class: text embedding, multimodal embedding, text similarity, question
 * answering and remote highlighting.
 *
 * <p>Inputs for locally executed functions are built from typed datasets, while inputs for
 * {@link FunctionName#REMOTE} carry a flat {@code String -> String} parameter map inside a
 * {@link RemoteInferenceInputDataSet}.
 */
public class NeuralSearchMLInputBuilder {
    private static final String INPUT_TEXT_FIELD = "inputText";
    private static final String INPUT_IMAGE_FIELD = "inputImage";
    private static final String QUESTION_FIELD = "question";
    private static final String CONTEXT_FIELD = "context";
    private static final String INPUTS_PARAMETER = "inputs";
    private static final String TEXTS_PARAMETER = "texts";
    private static final String CONTENT_TYPE_PARAMETER = "content_type";

    /**
     * Builds the input for a text embedding request.
     *
     * <p>When the model configuration is a {@link RemoteModelConfig} the model must be
     * asymmetric, and a {@link FunctionName#REMOTE} input is produced. Otherwise a
     * {@link FunctionName#TEXT_EMBEDDING} input over a {@link TextDocsInputDataSet} is produced.
     *
     * @param model                 model whose configuration selects the remote or local path
     * @param targetResponseFilters response filters forwarded to the {@link ModelResultFilter}
     * @param inputText             texts to embed
     * @param inferenceRequest      request supplying the algorithm parameters and content type
     * @return the assembled input
     * @throws IllegalArgumentException if the model is remote but not asymmetric, if an
     *                                  asymmetric model is used without an embedding content type,
     *                                  or if preset parameters for an asymmetric local model are
     *                                  not {@link AsymmetricTextEmbeddingParameters}
     */
    public static MLInput createTextEmbeddingInput(MLModel model, List<String> targetResponseFilters, List<String> inputText, InferenceRequest inferenceRequest) {
        boolean isAsymmetric = AsymmetricModelDetector.isAsymmetricModel(model);
        MLModelConfig modelConfig = model.getModelConfig();
        if (modelConfig instanceof RemoteModelConfig) {
            if (!isAsymmetric) {
                throw new IllegalArgumentException("Remote models are only supported for asymmetric E5 text embedding");
            }
            return createAsymmetricRemoteInput(inputText, inferenceRequest);
        }
        MLAlgoParams mlAlgoParams = createMLAlgoParams(isAsymmetric, inferenceRequest);
        return createLocalInput(FunctionName.TEXT_EMBEDDING, createTextDocsDataset(inputText, targetResponseFilters), mlAlgoParams);
    }

    /**
     * Wraps the given parameters in a {@link RemoteInferenceInputDataSet}.
     *
     * <p>Every value is converted with {@code toString()}; {@code null} values are kept as
     * {@code null} entries.
     */
    private static MLInput createRemoteInput(FunctionName functionName, Map<String, Object> parameters) {
        Map<String, String> stringParameters = new HashMap<>();
        for (Map.Entry<String, Object> entry : parameters.entrySet()) {
            Object value = entry.getValue();
            stringParameters.put(entry.getKey(), value != null ? value.toString() : null);
        }
        RemoteInferenceInputDataSet inputDataset = new RemoteInferenceInputDataSet(stringParameters);
        return MLInput.builder().algorithm(functionName).inputDataset(inputDataset).build();
    }

    /** Creates an input for a locally executed function; {@code mlAlgoParams} may be {@code null}. */
    private static MLInput createLocalInput(FunctionName functionName, MLInputDataset inputDataset, MLAlgoParams mlAlgoParams) {
        return new MLInput(functionName, mlAlgoParams, inputDataset);
    }

    /** Builds the text-docs dataset shared by the text and multimodal embedding paths. */
    private static MLInputDataset createTextDocsDataset(List<String> inputText, List<String> targetResponseFilters) {
        ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        return new TextDocsInputDataSet(inputText, modelResultFilter);
    }

    /**
     * Resolves the algorithm parameters for a local text embedding request.
     *
     * <p>Symmetric models use the request's parameters unchanged. Asymmetric models get
     * {@link AsymmetricTextEmbeddingParameters} carrying the request's content type, starting
     * from the preset parameters when the request provides them.
     *
     * @throws IllegalArgumentException if the content type is missing or the preset parameters
     *                                  are of another type
     */
    private static MLAlgoParams createMLAlgoParams(boolean isAsymmetric, InferenceRequest inferenceRequest) {
        if (!isAsymmetric) {
            return inferenceRequest.getMlAlgoParams();
        }
        EmbeddingContentType contentType = inferenceRequest.getEmbeddingContentType();
        if (contentType == null) {
            throw new IllegalArgumentException("embeddingContentType must be set for asymmetric local models");
        }
        MLAlgoParams presetParams = inferenceRequest.getMlAlgoParams();
        if (presetParams != null && !(presetParams instanceof AsymmetricTextEmbeddingParameters)) {
            throw new IllegalArgumentException("MLAlgoParams must be AsymmetricTextEmbeddingParameters for asymmetric models");
        }
        // Start from the preset values when present so only the content type is overridden.
        return (presetParams != null
            ? ((AsymmetricTextEmbeddingParameters) presetParams).toBuilder()
            : AsymmetricTextEmbeddingParameters.builder())
            .embeddingContentType(contentType.toMLContentType())
            .build();
    }

    /**
     * Builds a {@link FunctionName#TEXT_SIMILARITY} input comparing {@code query} with each
     * entry of {@code inputText}; no algorithm parameters are attached.
     *
     * @param query     text to compare against
     * @param inputText candidate texts
     * @return the assembled input
     */
    public static MLInput createTextSimilarityInput(String query, List<String> inputText) {
        return createLocalInput(FunctionName.TEXT_SIMILARITY, new TextSimilarityInputDataSet(query, inputText), null);
    }

    /**
     * Builds a {@link FunctionName#QUESTION_ANSWERING} input; no algorithm parameters are attached.
     *
     * @param question question text
     * @param context  context passed alongside the question
     * @return the assembled input
     */
    public static MLInput createQuestionAnsweringInput(String question, String context) {
        return createLocalInput(FunctionName.QUESTION_ANSWERING, new QuestionAnsweringInputDataSet(question, context), null);
    }

    /**
     * Builds a {@link FunctionName#REMOTE} input from an already prepared parameter map.
     *
     * @param parameters parameters copied into the remote dataset
     * @return the assembled input
     */
    public static MLInput createRemoteHighlightingInput(Map<String, String> parameters) {
        return createRemoteInput(FunctionName.REMOTE, new HashMap<String, Object>(parameters));
    }

    /**
     * Builds a {@link FunctionName#REMOTE} input whose {@code inputs} parameter is a JSON array
     * with one {@code {"question": ..., "context": ...}} object per request.
     *
     * @param batchRequests maps read through their {@code question} and {@code context} keys
     * @return the assembled input
     * @throws IllegalStateException if building the input fails for any reason
     */
    public static MLInput createBatchHighlightingInput(List<Map<String, String>> batchRequests) {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            builder.startArray();
            for (Map<String, String> request : batchRequests) {
                builder.startObject()
                    .field(QUESTION_FIELD, request.get(QUESTION_FIELD))
                    .field(CONTEXT_FIELD, request.get(CONTEXT_FIELD))
                    .endObject();
            }
            builder.endArray();

            Map<String, Object> parameters = new HashMap<>();
            parameters.put(INPUTS_PARAMETER, builder.toString());
            return createRemoteInput(FunctionName.REMOTE, parameters);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create batch highlighting ML input", e);
        }
    }

    /**
     * Builds a {@link FunctionName#REMOTE} input whose {@code inputs} parameter is a JSON array
     * holding a single {@code {"question": ..., "context": ...}} object.
     *
     * @param question question text
     * @param context  context passed alongside the question
     * @return the assembled input
     * @throws IllegalStateException if building the input fails for any reason
     */
    public static MLInput createSingleRemoteHighlightingInput(String question, String context) {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            builder.startArray();
            builder.startObject().field(QUESTION_FIELD, question).field(CONTEXT_FIELD, context).endObject();
            builder.endArray();

            Map<String, Object> parameters = new HashMap<>();
            parameters.put(INPUTS_PARAMETER, builder.toString());
            return createRemoteInput(FunctionName.REMOTE, parameters);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create remote highlighting ML input", e);
        }
    }

    /**
     * Builds a {@link FunctionName#REMOTE} input for an asymmetric model: {@code texts} holds
     * the input texts as a JSON array and {@code content_type} holds the lower-cased content type.
     *
     * @throws IllegalArgumentException if the request carries no embedding content type
     * @throws IllegalStateException    if building the input fails for any other reason
     */
    private static MLInput createAsymmetricRemoteInput(List<String> inputText, InferenceRequest inferenceRequest) {
        // Validate before the try block: inside it, the catch-all would re-wrap this as an
        // IllegalStateException and hide the fact that the caller's request is incomplete.
        EmbeddingContentType contentType = inferenceRequest.getEmbeddingContentType();
        if (contentType == null) {
            throw new IllegalArgumentException("embeddingContentType must be set for asymmetric remote models");
        }
        try (XContentBuilder textsBuilder = XContentFactory.jsonBuilder()) {
            textsBuilder.startArray();
            for (String text : inputText) {
                textsBuilder.value(text);
            }
            textsBuilder.endArray();

            Map<String, Object> parameters = new HashMap<>();
            parameters.put(TEXTS_PARAMETER, textsBuilder.toString());
            parameters.put(CONTENT_TYPE_PARAMETER, contentType.toString().toLowerCase(Locale.ROOT));
            return createRemoteInput(FunctionName.REMOTE, parameters);
        } catch (Exception e) {
            throw new IllegalStateException("Failed to create asymmetric remote ML input", e);
        }
    }

    /**
     * Builds an embedding input from a map of named inputs.
     *
     * <p>The value stored under {@code inputText} is always the first entry of the text list
     * (it is {@code null} when the key is absent); the value under {@code inputImage} is appended
     * only when that key is present. Asymmetric remote models produce a
     * {@link FunctionName#REMOTE} input; every other model produces a
     * {@link FunctionName#TEXT_EMBEDDING} input, with asymmetric parameters attached only when
     * the model is asymmetric and the request carries a content type.
     *
     * @param model                 model whose configuration selects the remote or local path
     * @param targetResponseFilters response filters forwarded to the {@link ModelResultFilter}
     * @param inputObjects          named inputs, read through {@code inputText} and {@code inputImage}
     * @param inferenceRequest      request supplying the embedding content type
     * @return the assembled input
     * @throws IllegalArgumentException if an asymmetric remote model is used without an
     *                                  embedding content type
     */
    public static MLInput createMultimodalInputFromMap(MLModel model, List<String> targetResponseFilters, Map<String, String> inputObjects, InferenceRequest inferenceRequest) {
        List<String> inputText = new ArrayList<>();
        inputText.add(inputObjects.get(INPUT_TEXT_FIELD));
        if (inputObjects.containsKey(INPUT_IMAGE_FIELD)) {
            inputText.add(inputObjects.get(INPUT_IMAGE_FIELD));
        }

        boolean isAsymmetric = AsymmetricModelDetector.isAsymmetricModel(model);
        MLModelConfig modelConfig = model.getModelConfig();
        if (isAsymmetric && modelConfig instanceof RemoteModelConfig) {
            return createAsymmetricRemoteInput(inputText, inferenceRequest);
        }

        MLAlgoParams mlAlgoParams = null;
        if (isAsymmetric) {
            EmbeddingContentType contentType = inferenceRequest.getEmbeddingContentType();
            if (contentType != null) {
                mlAlgoParams = AsymmetricTextEmbeddingParameters.builder().embeddingContentType(contentType.toMLContentType()).build();
            }
        }
        return createLocalInput(FunctionName.TEXT_EMBEDDING, createTextDocsDataset(inputText, targetResponseFilters), mlAlgoParams);
    }
}

