/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;

/**
 * Score combination technique registered under the name {@code "rrf"}.
 *
 * <p>It combines per-sub-query scores into one document score by summing the weighted scores of
 * every sub-query whose score is non-negative. The only accepted parameter is {@code "weights"}.
 * NOTE(review): the incoming scores are presumably already reciprocal-rank values produced by an
 * upstream RRF normalization step; confirm against the caller.
 */
public class RRFScoreCombinationTechnique
implements ScoreCombinationTechnique,
ExplainableTechnique {
    @Generated
    private static final Logger log = LogManager.getLogger(RRFScoreCombinationTechnique.class);
    /** Public name under which this technique is identified. */
    public static final String TECHNIQUE_NAME = "rrf";
    /** Parameter names accepted by the constructor; anything else is rejected by validateParams. */
    private static final Set<String> SUPPORTED_PARAMS = Set.of("weights");
    /** Result returned when no sub-query contributed any weight. */
    private static final Float ZERO_SCORE = Float.valueOf(0.0f);

    /** Per-sub-query weights, as resolved by {@link ScoreCombinationUtil#getWeights}. */
    private final List<Float> weights;
    /** Helper used for parameter validation and weight lookup. */
    private final ScoreCombinationUtil scoreCombinationUtil;

    /**
     * Creates the technique, validating the supplied parameters before resolving weights.
     *
     * @param params technique parameters; only {@code "weights"} is supported
     * @param combinationUtil helper used to validate params and look up weights
     */
    public RRFScoreCombinationTechnique(Map<String, Object> params, ScoreCombinationUtil combinationUtil) {
        this.scoreCombinationUtil = combinationUtil;
        // Validation must happen before weights are read so bad params fail fast.
        scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
        weights = scoreCombinationUtil.getWeights(params);
    }

    /**
     * Combines sub-query scores into a single score.
     *
     * <p>Each non-negative score is multiplied by its sub-query weight and added to the total.
     * Negative scores are skipped, and so are NaN scores, because the {@code >=} comparison is
     * false for NaN.
     *
     * @param scores one score per sub-query
     * @return the weighted sum of the non-negative scores, or {@code 0} when the accumulated weight
     *     is zero
     * @throws IllegalArgumentException if {@code scores} is null
     */
    @Override
    public float combine(float[] scores) {
        if (scores == null) {
            throw new IllegalArgumentException("scores array cannot be null");
        }
        scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);

        float weightedTotal = 0.0f;
        float accumulatedWeight = 0.0f;
        for (int subQuery = 0; subQuery < scores.length; subQuery++) {
            final float rawScore = scores[subQuery];
            if (rawScore >= 0.0f) {
                // The weight is only looked up for sub-queries that actually contribute.
                final float subQueryWeight = scoreCombinationUtil.getWeightForSubQuery(weights, subQuery);
                weightedTotal += rawScore * subQueryWeight;
                accumulatedWeight += subQueryWeight;
            }
        }

        // Guard: if nothing contributed weight, report a zero score rather than the raw sum.
        return accumulatedWeight == 0.0f ? ZERO_SCORE.floatValue() : weightedTotal;
    }

    /** @return the technique name, {@code "rrf"} */
    @Override
    public String techniqueName() {
        return TECHNIQUE_NAME;
    }

    /**
     * Describes this technique for explain output.
     *
     * @return the description built by ExplanationUtils with an empty parameter list
     */
    @Override
    public String describe() {
        return ExplanationUtils.describeCombinationTechnique(TECHNIQUE_NAME, List.of());
    }

    @Generated
    @Override
    public String toString() {
        final String weightsText = String.valueOf(weights);
        return "RRFScoreCombinationTechnique(TECHNIQUE_NAME=rrf, weights=" + weightsText + ")";
    }
}

