package org.apache.ignite.ml.selection.cv;

import java.lang.invoke.SerializedLambda;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
import org.apache.ignite.ml.selection.scoring.cursor.LocalLabelPairCursor;
import org.apache.ignite.ml.selection.scoring.metric.Metric;
import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/* loaded from: input_file:org/apache/ignite/ml/selection/cv/CrossValidation.class */
/**
 * K-fold cross-validation scorer.
 *
 * <p>Every upstream entry is mapped to a point by a {@link UniformMapper} (expected to lie in {@code [0, 1]}).
 * For {@code cv} folds that range is cut into {@code cv} equal, non-overlapping sub-intervals. For fold {@code i}
 * the model is trained on the entries whose point lies outside the {@code i}-th sub-interval. It is then scored
 * with the supplied {@link Metric} on the entries whose point lies inside it, so every entry is tested exactly once.
 *
 * @param <M> Type of model.
 * @param <L> Type of a label (truth or prediction).
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public class CrossValidation<M extends Model<Vector, L>, L, K, V> {
    /**
     * Computes per-fold scores over the whole Ignite cache, assigning folds with a {@link SHA256UniformMapper}.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param ignite Ignite instance.
     * @param upstreamCache Ignite cache with upstream data.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
        IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor,
            new SHA256UniformMapper<>(), cv);
    }

    /**
     * Computes per-fold scores over the filtered Ignite cache, assigning folds with a {@link SHA256UniformMapper}.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param ignite Ignite instance.
     * @param upstreamCache Ignite cache with upstream data.
     * @param filter Entries rejected by this predicate are used neither for training nor for testing.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
            new SHA256UniformMapper<>(), cv);
    }

    /**
     * Grid-searches hyper-parameters. Every parameter combination produced from {@code paramGrid} is cross-validated,
     * and the scores of all combinations are collected together with the best one.
     *
     * <p>Each parameter {@code name} is applied by reflectively invoking the trainer's single-argument
     * {@code withName} method (first letter upper-cased). The return value of that method is ignored, so it is
     * expected to mutate {@code trainer} in place. After this method returns, the trainer is left configured with
     * the <em>last</em> combination tried, not necessarily the best one.
     *
     * @param trainer Trainer of the model; mutated by this method.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param ignite Ignite instance.
     * @param upstreamCache Ignite cache with upstream data.
     * @param filter Entries rejected by this predicate are used neither for training nor for testing.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param cv Number of folds; must be at least 2.
     * @param paramGrid Hyper-parameter names and candidate values.
     * @return Scores for every combination plus the best scores and hyper-parameters.
     * @throws IllegalArgumentException If {@code cv < 2}, a parameter name is empty, the trainer has no matching
     *     setter, or the setter rejects the value.
     * @throws IllegalStateException If a setter is inaccessible or fails with a checked exception.
     */
    public CrossValidationResult score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv,
        ParamGrid paramGrid) {
        List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

        CrossValidationResult cvRes = new CrossValidationResult();

        for (Double[] paramSet : paramSets) {
            Map<String, Double> paramMap = new HashMap<>();

            for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) {
                String paramName = paramGrid.getParamNameByIndex(paramIdx);
                Double paramVal = paramSet[paramIdx];

                paramMap.put(paramName, paramVal);

                // Fail fast: silently skipping a setter would score (and maybe pick as best) an
                // unapplied configuration.
                applyHyperParameter(trainer, paramName, paramVal);
            }

            double[] locScores = score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor,
                lbExtractor, new SHA256UniformMapper<>(), cv);

            cvRes.addScores(locScores, paramMap);

            // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, not the most negative one; kept for
            // consistency with CrossValidationResult's best-average bookkeeping — confirm for negative-valued metrics.
            double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE);

            if (locAvgScore > cvRes.getBestAvgScore()) {
                cvRes.setBestScore(locScores);
                cvRes.setBestHyperParams(paramMap);
            }
        }

        return cvRes;
    }

    /**
     * Computes per-fold scores over the filtered Ignite cache using the given fold mapper.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param ignite Ignite instance.
     * @param upstreamCache Ignite cache with upstream data.
     * @param filter Entries rejected by this predicate are used neither for training nor for testing.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param mapper Maps each entry to a point that decides its fold.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
        UniformMapper<K, V> mapper, int cv) {
        return scoreFolds(
            trainer,
            trainSetFilter -> new CacheBasedDatasetBuilder<>(
                ignite,
                upstreamCache,
                (k, v) -> filter.apply(k, v) && trainSetFilter.apply(k, v)
            ),
            (trainSetFilter, mdl) -> new CacheBasedLabelPairCursor<>(
                upstreamCache,
                (k, v) -> filter.apply(k, v) && !trainSetFilter.apply(k, v),
                featureExtractor,
                lbExtractor,
                mdl
            ),
            featureExtractor,
            lbExtractor,
            scoreCalculator,
            mapper,
            cv
        );
    }

    /**
     * Computes per-fold scores over a local map, assigning folds with a {@link SHA256UniformMapper}.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param upstreamMap Map with upstream data.
     * @param parts Number of partitions of the local dataset.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
        int parts, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor,
            new SHA256UniformMapper<>(), cv);
    }

    /**
     * Computes per-fold scores over a filtered local map, assigning folds with a {@link SHA256UniformMapper}.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param upstreamMap Map with upstream data.
     * @param filter Entries rejected by this predicate are used neither for training nor for testing.
     * @param parts Number of partitions of the local dataset.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
        return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor,
            new SHA256UniformMapper<>(), cv);
    }

    /**
     * Computes per-fold scores over a filtered local map using the given fold mapper.
     *
     * @param trainer Trainer of the model.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param upstreamMap Map with upstream data.
     * @param filter Entries rejected by this predicate are used neither for training nor for testing.
     * @param parts Number of partitions of the local dataset.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param mapper Maps each entry to a point that decides its fold.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    public double[] score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Map<K, V> upstreamMap,
        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, Vector> featureExtractor,
        IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
        return scoreFolds(
            trainer,
            trainSetFilter -> new LocalDatasetBuilder<>(
                upstreamMap,
                (k, v) -> filter.apply(k, v) && trainSetFilter.apply(k, v),
                parts
            ),
            (trainSetFilter, mdl) -> new LocalLabelPairCursor<>(
                upstreamMap,
                (k, v) -> filter.apply(k, v) && !trainSetFilter.apply(k, v),
                featureExtractor,
                lbExtractor,
                mdl
            ),
            featureExtractor,
            lbExtractor,
            scoreCalculator,
            mapper,
            cv
        );
    }

    /**
     * Runs the k-fold loop: for each fold, trains on the complement of the fold and scores on the fold itself.
     *
     * @param trainer Trainer of the model.
     * @param datasetBuilderSupplier Builds the training dataset from a "belongs to training set" predicate.
     * @param testDataIterSupplier Builds a cursor over test (label, prediction) pairs from that same predicate and
     *     the trained model.
     * @param featureExtractor Feature extractor.
     * @param lbExtractor Label extractor.
     * @param scoreCalculator Metric used to score each held-out fold.
     * @param mapper Maps each entry to a point that decides its fold.
     * @param cv Number of folds; must be at least 2.
     * @return One score per fold.
     * @throws IllegalArgumentException If {@code cv < 2}.
     */
    private double[] scoreFolds(DatasetTrainer<M, L> trainer,
        Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier,
        BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
        Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
        // cv <= 0 cannot size the result; cv == 1 would leave the training set empty by construction.
        if (cv < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2, but was " + cv);
        }

        double[] scores = new double[cv];

        for (int i = 0; i < cv; i++) {
            // Integer division (rather than i * (1.0 / cv)) makes the last upper bound exactly 1.0.
            double from = (double) i / cv;
            double to = (double) (i + 1) / cv;
            boolean lastFold = i == cv - 1;

            // Test fold i is the half-open interval [from, to); the last fold also keeps everything >= from. This way
            // a point lying exactly on a boundary is tested in exactly one fold and trained on in all the others.
            IgniteBiPredicate<K, V> trainSetFilter = (k, v) -> {
                double pnt = mapper.map(k, v);
                return pnt < from || (!lastFold && pnt >= to);
            };

            M mdl = trainer.fit(datasetBuilderSupplier.apply(trainSetFilter), featureExtractor, lbExtractor);

            // try-with-resources guarantees the cursor is closed even if scoring throws.
            try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) {
                scores[i] = scoreCalculator.score(cursor.iterator());
            }
            catch (RuntimeException e) {
                throw e;
            }
            catch (Exception e) {
                // Closing the cursor may raise a checked exception.
                throw new RuntimeException("Failed to score fold " + i, e);
            }
        }

        return scores;
    }

    /**
     * Applies a single hyper-parameter to the trainer by invoking its single-argument {@code withXxx} method,
     * where {@code Xxx} is {@code paramName} with the first letter upper-cased.
     *
     * @param trainer Trainer to configure.
     * @param paramName Hyper-parameter name.
     * @param paramVal Value from the grid; converted to the setter's declared numeric type if needed.
     * @throws IllegalArgumentException If the name is empty, no matching setter exists, or the setter rejects the
     *     value (including a runtime exception thrown by the setter itself).
     * @throws IllegalStateException If the setter is inaccessible or fails with a checked exception.
     */
    private static void applyHyperParameter(Object trainer, String paramName, Double paramVal) {
        if (paramName == null || paramName.isEmpty()) {
            throw new IllegalArgumentException("Hyper-parameter name must not be empty: " + paramName);
        }

        // Locale.ROOT keeps "iterations" -> "withIterations" even under e.g. the Turkish default locale.
        String setterName = "with" + paramName.substring(0, 1).toUpperCase(Locale.ROOT) + paramName.substring(1);

        Method setter = findSetter(trainer.getClass(), setterName);

        if (setter == null) {
            throw new IllegalArgumentException("Trainer " + trainer.getClass().getName()
                + " has no single-argument method " + setterName + " for hyper-parameter " + paramName,
                new NoSuchMethodException(setterName));
        }

        try {
            setter.invoke(trainer, convert(paramVal, setter.getParameterTypes()[0]));
        }
        catch (IllegalAccessException e) {
            throw new IllegalStateException("Cannot access " + setter, e);
        }
        catch (InvocationTargetException e) {
            Throwable cause = e.getCause();

            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }

            throw new IllegalStateException("Setter " + setter + " failed", cause);
        }
    }

    /**
     * Finds a single-argument method with the given name on the class or any of its superclasses.
     *
     * @param cls Class to start searching from; subclasses are searched before superclasses.
     * @param name Method name.
     * @return Matching method, or {@code null} if none exists.
     */
    private static Method findSetter(Class<?> cls, String name) {
        // The parameter type is unknown here, so match on name and arity only.
        for (Class<?> c = cls; c != null; c = c.getSuperclass()) {
            for (Method mtd : c.getDeclaredMethods()) {
                if (mtd.getName().equals(name) && mtd.getParameterCount() == 1) {
                    return mtd;
                }
            }
        }

        return null;
    }

    /**
     * Converts a grid value to the numeric type a setter declares.
     *
     * @param val Grid value (may be {@code null}).
     * @param type Declared parameter type of the setter.
     * @return {@code val} narrowed to {@code int}/{@code long}/{@code float} (truncating like a cast) when the setter
     *     expects one of those, otherwise {@code val} unchanged.
     */
    private static Object convert(Double val, Class<?> type) {
        if (val == null) {
            return null;
        }
        if (type == int.class || type == Integer.class) {
            return val.intValue();
        }
        if (type == long.class || type == Long.class) {
            return val.longValue();
        }
        if (type == float.class || type == Float.class) {
            return val.floatValue();
        }

        return val;
    }
}
