package org.apache.ignite.ml.dataset.feature.extractor;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;

/* loaded from: input_file:org/apache/ignite/ml/dataset/feature/extractor/Vectorizer.class */
public abstract class Vectorizer<K, V, C extends Serializable, L> implements FeatureLabelExtractor<K, V, L> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 4301406952131379459L;
    /** Set by the constructor: true when no extraction coordinates were passed, so apply() reads them from allCoords(). */
    private final boolean useAllValues;
    /** Coordinates passed to the constructor; used by apply() when useAllValues is false. */
    private List<C> extractionCoordinates;
    /** Label coordinate: set explicitly by labeled(C), or resolved and cached by labelCoord(k, v) from the shortcut. */
    private C labelCoord;
    /** Positional label selector set by labeled(LabelCoordinate); null when no shortcut is configured. */
    private LabelCoordinate lbCoordinateShortcut = null;
    /** Coordinates added through exclude(); apply() filters them out when useAllValues is true. */
    private HashSet<C> excludedCoords = new HashSet<>();

    /* loaded from: input_file:org/apache/ignite/ml/dataset/feature/extractor/Vectorizer$LabelCoordinate.class */
    /** Positional shortcut for choosing the label coordinate out of the list returned by allCoords(). */
    public enum LabelCoordinate {
        /** The label is the first element of allCoords(). */
        FIRST,
        /** The label is the last element of allCoords(). */
        LAST
    }

    /* loaded from: input_file:org/apache/ignite/ml/dataset/feature/extractor/Vectorizer$VectorizerAdapter.class */
    /**
     * Vectorizer base class whose {@code feature}, {@code label}, {@code zero} and {@code allCoords}
     * implementations all throw {@link IllegalStateException}.
     *
     * <p>It is constructed without extraction coordinates, so the inherited {@code apply} would call
     * {@code allCoords} and fail unless a subclass overrides the relevant methods.
     * NOTE(review): presumably meant for subclasses that override {@code apply}/{@code extract}
     * directly - confirm against callers.
     *
     * @param <K> key type of the upstream record
     * @param <V> value type of the upstream record
     * @param <C> coordinate type
     * @param <L> label type
     */
    public static abstract class VectorizerAdapter<K, V, C extends Serializable, L> extends Vectorizer<K, V, C, L> {
        /**
         * Creates an adapter with no explicit extraction coordinates.
         */
        public VectorizerAdapter() {
            // Empty varargs call: the compiler creates the zero-length array itself. An explicit
            // Serializable[] is not assignable to the C... parameter.
            super();
        }

        /**
         * Not supported by the adapter.
         *
         * @throws IllegalStateException always
         */
        @Override
        protected Double feature(C c, K k, V v) {
            throw new IllegalStateException();
        }

        /**
         * Not supported by the adapter.
         *
         * @throws IllegalStateException always
         */
        @Override
        protected L label(C c, K k, V v) {
            throw new IllegalStateException();
        }

        /**
         * Not supported by the adapter.
         *
         * @throws IllegalStateException always
         */
        @Override
        protected L zero() {
            throw new IllegalStateException();
        }

        /**
         * Not supported by the adapter.
         *
         * @throws IllegalStateException always
         */
        @Override
        protected List<C> allCoords(K k, V v) {
            throw new IllegalStateException();
        }

        // The erased-signature bridge methods (feature(Serializable, Object, Object) and
        // apply(Object, Object)) are generated by the compiler and must not be declared by hand:
        // they have the same erasure as the methods above / the inherited apply(K, V).
    }

    /**
     * Builds a labeled vector from one key/value record.
     *
     * <p>The label comes from {@code label(...)} at the configured label coordinate, or from
     * {@code zero()} when no label is configured. Features are read at the constructor-supplied
     * coordinates; when none were supplied, they are read at every coordinate from
     * {@code allCoords(k, v)} except the label coordinate and the excluded ones.
     * A {@code null} feature is not written to the vector.
     *
     * @param k record key
     * @param v record value
     * @return vector of extracted features together with the label
     */
    @Override // java.util.function.BiFunction
    public LabeledVector<L> apply(K k, V v) {
        // Resolve the label first: labelCoord(k, v) may cache this.labelCoord, which the filter below reads.
        final L lbl;
        if (isLabeled()) {
            lbl = label(labelCoord(k, v), k, v);
        } else {
            lbl = zero();
        }

        final List<C> coords;
        if (this.useAllValues) {
            coords = allCoords(k, v).stream()
                .filter(coord -> !coord.equals(this.labelCoord) && !this.excludedCoords.contains(coord))
                .collect(Collectors.toList());
        } else {
            coords = this.extractionCoordinates;
        }

        final int vectorLength = coords.size();
        A.ensure(vectorLength >= 0, "vectorLength >= 0");

        final Vector features = createVector(vectorLength);
        int pos = 0;
        for (C coord : coords) {
            final Serializable val = feature(coord, k, v);
            if (val != null) {
                features.setRaw(pos, val);
            }
            pos++;
        }
        return new LabeledVector<>(features, lbl);
    }

    /**
     * Creates a vectorizer that extracts features at the given coordinates.
     *
     * <p>When no coordinates are passed, the vectorizer switches to "use all values" mode and
     * {@code apply} takes its coordinates from {@code allCoords(k, v)} instead.
     *
     * @param cArr coordinates to extract, in vector order; may be empty
     */
    public Vectorizer(C... cArr) {
        // Arrays.asList is a view over the array it is given, so copy the array first;
        // otherwise a caller that later writes to its array would silently change the coordinates.
        this.extractionCoordinates = Arrays.asList(cArr.clone());
        this.useAllValues = cArr.length == 0;
    }

    /**
     * Tells whether a label source is configured.
     *
     * @return true if either an explicit label coordinate or a positional shortcut is set
     */
    private boolean isLabeled() {
        return this.labelCoord != null || this.lbCoordinateShortcut != null;
    }

    /**
     * Returns the label coordinate for the given record.
     *
     * <p>An already-known coordinate is returned as is. Otherwise it is picked from
     * {@code allCoords(k, v)} according to the configured shortcut and cached in
     * {@code this.labelCoord}, so later calls return the cached value.
     *
     * @param k record key
     * @param v record value
     * @return coordinate of the label
     * @throws IllegalArgumentException if the shortcut is neither FIRST nor LAST
     */
    private C labelCoord(K k, V v) {
        A.ensure(isLabeled(), "isLabeled");
        final C known = this.labelCoord;
        if (known != null) {
            return known;
        }

        final List<C> coords = allCoords(k, v);
        A.ensure(!coords.isEmpty(), "!allCoords.isEmpty()");

        final int labelIdx;
        switch (this.lbCoordinateShortcut) {
            case FIRST:
                labelIdx = 0;
                break;
            case LAST:
                labelIdx = coords.size() - 1;
                break;
            default:
                throw new IllegalArgumentException();
        }
        this.labelCoord = coords.get(labelIdx);
        return this.labelCoord;
    }

    /**
     * Sets an explicit label coordinate and drops any positional shortcut.
     *
     * @param c coordinate of the label
     * @return this vectorizer, for chaining
     */
    public Vectorizer<K, V, C, L> labeled(C c) {
        // The explicit coordinate replaces the shortcut.
        this.lbCoordinateShortcut = null;
        this.labelCoord = c;
        return this;
    }

    /**
     * Selects the label by position (first or last coordinate) and drops any explicit label coordinate.
     *
     * @param labelCoordinate positional shortcut for the label
     * @return this vectorizer, for chaining
     */
    public Vectorizer<K, V, C, L> labeled(LabelCoordinate labelCoordinate) {
        // Clear the cached coordinate so it is resolved again from the new shortcut.
        this.labelCoord = null;
        this.lbCoordinateShortcut = labelCoordinate;
        return this;
    }

    /**
     * Adds coordinates to the exclusion set. The set is consulted by {@code apply} only when
     * the vectorizer was created without explicit extraction coordinates.
     *
     * @param cArr coordinates to exclude
     * @return this vectorizer, for chaining
     */
    public Vectorizer<K, V, C, L> exclude(C... cArr) {
        for (C coord : cArr) {
            this.excludedCoords.add(coord);
        }
        return this;
    }

    /**
     * Extracts a labeled vector from one key/value record; delegates to {@link #apply(Object, Object)}.
     *
     * @param k record key
     * @param v record value
     * @return vector of extracted features together with the label
     */
    @Override // org.apache.ignite.ml.trainers.FeatureLabelExtractor
    public LabeledVector<L> extract(K k, V v) {
        // No casts: k and v already have the types apply(K, V) expects. Casting the key to
        // Vectorizer would throw ClassCastException for every ordinary key.
        return apply(k, v);
    }

    /**
     * Extracts the raw feature value stored at one coordinate of a record.
     * {@code apply} writes a non-null result into the output vector and skips a null one.
     *
     * @param c coordinate of the feature
     * @param k record key
     * @param v record value
     * @return feature value, or null if there is nothing to write at this coordinate
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public abstract Serializable feature(C c, K k, V v);

    /**
     * Extracts the label of a record; called by {@code apply} with the resolved label coordinate.
     *
     * @param c coordinate of the label
     * @param k record key
     * @param v record value
     * @return label value
     */
    protected abstract L label(C c, K k, V v);

    /**
     * Supplies the label that {@code apply} uses when no label coordinate or shortcut is configured.
     *
     * @return default label value
     */
    protected abstract L zero();

    /**
     * Lists every coordinate available in a record. Used by {@code apply} when no explicit
     * extraction coordinates were given, and by {@code labelCoord} to resolve the FIRST/LAST shortcut
     * (which requires the list to be non-empty).
     *
     * @param k record key
     * @param v record value
     * @return all coordinates of the record
     */
    protected abstract List<C> allCoords(K k, V v);

    /**
     * Allocates the vector that {@code apply} fills with features. Subclasses may override
     * this to supply a different vector implementation.
     *
     * @param i size of the vector
     * @return new dense vector of the requested size
     */
    protected Vector createVector(int i) {
        final Vector vec = new DenseVector(i);
        return vec;
    }

    // Erased-signature counterpart of apply(K, V), marked bridge/synthetic by the decompiler.
    // NOTE(review): javac generates this bridge on its own, and a hand-written copy has the same
    // erasure as apply(K, V); the cast of the key to Vectorizer also looks like a decompiler
    // type-inference failure (see the warning below). Candidate for removal - confirm before deleting.
    /* JADX WARN: Multi-variable type inference failed */
    @Override // java.util.function.BiFunction
    public /* bridge */ /* synthetic */ Object apply(Object obj, Object obj2) {
        return apply((Vectorizer<K, V, C, L>) obj, obj2);
    }
}
