/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.internal.managers.encryption;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import org.apache.ignite.IgniteLogger;
import org.apache.ignite.internal.managers.encryption.GroupKey;
import org.apache.ignite.internal.managers.encryption.GroupKeyEncrypted;
import org.apache.ignite.internal.util.typedef.F;
import org.apache.ignite.internal.util.typedef.internal.U;
import org.apache.ignite.lang.IgnitePredicate;
import org.apache.ignite.spi.encryption.EncryptionSpi;
import org.jetbrains.annotations.Nullable;

/**
 * In-memory registry of cache group encryption keys and of the WAL segments that still depend on them.
 * <p>
 * For each cache group a list of {@link GroupKey}s is stored. The element at index {@code 0} is the active key;
 * the remaining elements are additional keys of the same group (for example keys that WAL segments still
 * reference, or keys added but not yet activated). Every list put into the map is a {@link CopyOnWriteArrayList}
 * and the map itself is a {@link ConcurrentHashMap}, so a single read never observes a half-copied list.
 * Multi-step operations such as {@link #changeActiveKey(int, int)} are not atomic as a whole.
 * NOTE(review): presumably mutators are invoked from a serialized context — confirm with callers.
 */
class CacheGroupEncryptionKeys {
    /** Cache group id -> keys of that group; index {@code 0} is the active key. */
    private final Map<Integer, List<GroupKey>> grpKeys = new ConcurrentHashMap<>();

    /**
     * WAL segments that pin a particular (group, key) pair. Iterated in insertion order; the release logic in
     * {@link #isReleaseWalKeysRequired(long)} and {@link #releaseWalKeys(long)} assumes segments are added with
     * non-decreasing WAL index.
     */
    private final Collection<TrackedWalSegment> trackedWalSegments = new ConcurrentLinkedQueue<>();

    /** SPI used to encrypt keys on export and decrypt them on import. */
    private final EncryptionSpi encSpi;

    /** Logger. */
    private final IgniteLogger log;

    /**
     * @param encSpi Encryption SPI used to (de)crypt group keys.
     * @param log Logger.
     */
    CacheGroupEncryptionKeys(EncryptionSpi encSpi, IgniteLogger log) {
        this.encSpi = encSpi;
        this.log = log;
    }

    /**
     * Returns the active (first) key of a cache group.
     *
     * @param grpId Cache group id.
     * @return Active key, or {@code null} if the group is unknown or currently has no keys.
     */
    @Nullable
    GroupKey getActiveKey(int grpId) {
        List<GroupKey> keys = grpKeys.get(grpId);

        if (F.isEmpty(keys)) {
            return null;
        }

        GroupKey key = keys.get(0);

        assert key != null : "grpId=" + grpId;

        return key;
    }

    /**
     * Looks up a key of a cache group by its identifier.
     *
     * @param grpId Cache group id.
     * @param keyId Key identifier, compared against {@link GroupKey#unsignedId()}.
     * @return Matching key, or {@code null} if the group is unknown or no key has that identifier.
     */
    @Nullable
    GroupKey getKey(int grpId, int keyId) {
        List<GroupKey> keys = grpKeys.get(grpId);

        if (F.isEmpty(keys)) {
            return null;
        }

        for (GroupKey groupKey : keys) {
            if (groupKey.unsignedId() == keyId) {
                return groupKey;
            }
        }

        if (log.isDebugEnabled()) {
            log.debug("No keys matching specified keyId=" + keyId + " was found");
        }

        return null;
    }

    /**
     * Returns the identifiers of all keys of a cache group, active key first.
     *
     * @param grpId Cache group id.
     * @return Key identifiers (possibly empty), or {@code null} if the group is unknown.
     */
    @Nullable
    List<Integer> keyIds(int grpId) {
        List<GroupKey> keys = grpKeys.get(grpId);

        if (keys == null) {
            return null;
        }

        List<Integer> keyIds = new ArrayList<>(keys.size());

        for (GroupKey groupKey : keys) {
            keyIds.add(groupKey.unsignedId());
        }

        return keyIds;
    }

    /**
     * @return Live view of the ids of all cache groups that have a key list registered.
     */
    Set<Integer> groupIds() {
        return grpKeys.keySet();
    }

    /**
     * Exports the active key of every cache group, encrypted with the encryption SPI.
     *
     * @return Cache group id -> encrypted active key, or {@code null} if no group is registered.
     */
    @Nullable
    HashMap<Integer, GroupKeyEncrypted> getAll() {
        if (F.isEmpty(grpKeys)) {
            return null;
        }

        HashMap<Integer, GroupKeyEncrypted> keys = U.newHashMap(grpKeys.size());

        for (Map.Entry<Integer, List<GroupKey>> entry : grpKeys.entrySet()) {
            List<GroupKey> grpKeyList = entry.getValue();

            // addKey() registers an empty list before inserting into it, and setGroupKeys() may store an
            // empty list; such a group has no active key, and get(0) would throw IndexOutOfBoundsException.
            if (F.isEmpty(grpKeyList)) {
                continue;
            }

            GroupKey activeKey = grpKeyList.get(0);

            keys.put(entry.getKey(),
                new GroupKeyEncrypted(activeKey.unsignedId(), encSpi.encryptKey(activeKey.key())));
        }

        return keys;
    }

    /**
     * Exports all keys of a cache group, encrypted with the encryption SPI, preserving list order.
     *
     * @param grpId Cache group id.
     * @return Encrypted keys (active key first), or {@code null} if the group is unknown or has no keys.
     */
    @Nullable
    List<GroupKeyEncrypted> getAll(int grpId) {
        List<GroupKey> keys = grpKeys.get(grpId);

        if (F.isEmpty(keys)) {
            return null;
        }

        List<GroupKeyEncrypted> encryptedKeys = new ArrayList<>(keys.size());

        for (GroupKey grpKey : keys) {
            encryptedKeys.add(new GroupKeyEncrypted(grpKey.unsignedId(), encSpi.encryptKey(grpKey.key())));
        }

        return encryptedKeys;
    }

    /**
     * Makes an already registered key the active one.
     * <p>
     * The key with {@code keyId} is moved to index {@code 0}; any other list entries carrying the same id are
     * removed. The previously active key stays in the list.
     *
     * @param grpId Cache group id; must have at least one key.
     * @param keyId Identifier of the key to activate; must differ from the current active key's id.
     * @return Previously active key.
     */
    GroupKey changeActiveKey(int grpId, int keyId) {
        if (log.isDebugEnabled()) {
            log.debug("Change active encryption key [grpId=" + grpId + ", keyId=" + keyId + ']');
        }

        List<GroupKey> keys = grpKeys.get(grpId);

        assert !F.isEmpty(keys) : "grpId=" + grpId;

        GroupKey prevKey = keys.get(0);

        assert prevKey.unsignedId() != keyId : "keyId=" + keyId;

        // Search from the tail: keys are appended by addKey(), so the key being activated is usually the last one.
        GroupKey newKey = null;
        ListIterator<GroupKey> itr = keys.listIterator(keys.size());

        while (itr.hasPrevious()) {
            GroupKey key = itr.previous();

            if (key.unsignedId() == keyId) {
                newKey = key;

                break;
            }
        }

        assert newKey != null : "exp=" + keyId + ", act=" + keys;

        keys.add(0, newKey);

        // Drop the old position of the new active key (and any other entries with the same id).
        keys.subList(1, keys.size()).removeIf(k -> k.unsignedId() == keyId);

        return prevKey;
    }

    /**
     * Decrypts and appends a key to a cache group's key list, creating the list if needed.
     *
     * @param grpId Cache group id.
     * @param newEncKey Encrypted key to add.
     * @return {@code true} if the key was added, {@code false} if an equal key was already present.
     */
    boolean addKey(int grpId, GroupKeyEncrypted newEncKey) {
        if (log.isDebugEnabled()) {
            log.debug("Add new encryption key [grpId=" + grpId + ", keyId=" + newEncKey.id() + ']');
        }

        // Decrypt before touching the map: if decryption fails no empty key list is left registered.
        GroupKey grpKey = new GroupKey(newEncKey.id(), encSpi.decryptKey(newEncKey.key()));

        List<GroupKey> keys = grpKeys.computeIfAbsent(grpId, v -> new CopyOnWriteArrayList<>());

        return !keys.contains(grpKey) && keys.add(grpKey);
    }

    /**
     * Replaces all keys of a cache group. The first key of {@code encryptedKeys} becomes the active key.
     *
     * @param grpId Cache group id.
     * @param encryptedKeys Encrypted keys in the desired order.
     */
    void setGroupKeys(int grpId, List<GroupKeyEncrypted> encryptedKeys) {
        if (log.isDebugEnabled()) {
            log.debug("Set new encryption key(s) [grpId=" + grpId + ", keys=" +
                encryptedKeys.stream().map(GroupKeyEncrypted::id).collect(Collectors.toList()) + ']');
        }

        List<GroupKey> decrypted = new ArrayList<>(encryptedKeys.size());

        for (GroupKeyEncrypted encKey : encryptedKeys) {
            decrypted.add(new GroupKey(encKey.id(), encSpi.decryptKey(encKey.key())));
        }

        // Wrap once instead of appending element by element (each CopyOnWriteArrayList.add copies the array).
        grpKeys.put(grpId, new CopyOnWriteArrayList<>(decrypted));
    }

    /**
     * Forgets all keys of a cache group.
     *
     * @param grpId Cache group id.
     * @return Removed key list, or {@code null} if the group was unknown.
     */
    List<GroupKey> remove(int grpId) {
        return grpKeys.remove(grpId);
    }

    /**
     * Removes non-active keys with the given identifiers; the active key (index {@code 0}) is never removed.
     *
     * @param grpId Cache group id.
     * @param ids Key identifiers to remove.
     * @return {@code true} if at least one key was removed.
     */
    boolean removeKeysById(int grpId, Set<Integer> ids) {
        List<GroupKey> keys = grpKeys.get(grpId);

        if (F.isEmpty(keys)) {
            return false;
        }

        return keys.subList(1, keys.size()).removeIf(key -> ids.contains(key.unsignedId()));
    }

    /**
     * Removes the non-active keys of a cache group that are not referenced by any tracked WAL segment.
     *
     * @param grpId Cache group id.
     * @return Identifiers of removed keys, or an empty set if nothing was removed (including unknown groups).
     */
    Set<Integer> removeUnusedKeys(int grpId) {
        List<GroupKey> keys = grpKeys.get(grpId);

        // Unknown group, or only the active key present: nothing can be removed.
        if (keys == null || keys.size() < 2) {
            return Collections.emptySet();
        }

        List<GroupKey> inactive = keys.subList(1, keys.size());

        Set<Integer> rmvKeyIds = U.newHashSet(inactive.size());

        for (GroupKey key : inactive) {
            rmvKeyIds.add(key.unsignedId());
        }

        // Keep keys that are still needed by some WAL segment of this group.
        for (TrackedWalSegment segment : trackedWalSegments) {
            if (segment.grpId == grpId) {
                rmvKeyIds.remove(segment.keyId);
            }
        }

        // Restrict removal to the inactive tail, as removeKeysById() does, so the active key is never dropped.
        if (keys.subList(1, keys.size()).removeIf(key -> rmvKeyIds.contains(key.unsignedId()))) {
            return rmvKeyIds;
        }

        return Collections.emptySet();
    }

    /**
     * @return Read-only view of the tracked WAL segments.
     */
    Collection<TrackedWalSegment> trackedWalSegments() {
        return Collections.unmodifiableCollection(trackedWalSegments);
    }

    /**
     * Starts tracking the given WAL segments (appended after the already tracked ones).
     *
     * @param segments Segments to track.
     */
    void trackedWalSegments(Collection<TrackedWalSegment> segments) {
        if (log.isDebugEnabled()) {
            log.debug("Reserve WAL encryption keys, segments=[" + segments.stream()
                .map(s -> "[grpId=" + s.grpId + ", keyId=" + s.keyId + ", walIdx=" + s.idx + "]")
                .collect(Collectors.joining(", ")) + "]");
        }

        trackedWalSegments.addAll(segments);
    }

    /**
     * Records that the WAL segment with index {@code walIdx} references key {@code keyId} of group {@code grpId}.
     *
     * @param grpId Cache group id.
     * @param keyId Key identifier.
     * @param walIdx WAL segment index.
     */
    void reserveWalKey(int grpId, int keyId, long walIdx) {
        if (log.isDebugEnabled()) {
            log.debug("Reserve WAL encryption key [grpId=" + grpId + ", keyId=" + keyId + ", walIdx=" + walIdx + "]");
        }

        trackedWalSegments.add(new TrackedWalSegment(walIdx, grpId, keyId));
    }

    /**
     * @param grpId Cache group id.
     * @param keyId Key identifier.
     * @return WAL index of the first tracked segment (in insertion order) for this group/key, or {@code null}.
     */
    @Nullable
    Long reservedSegment(int grpId, int keyId) {
        for (TrackedWalSegment segment : trackedWalSegments) {
            if (segment.grpId == grpId && segment.keyId == keyId) {
                return segment.idx;
            }
        }

        return null;
    }

    /**
     * Checks whether the head of the tracked-segment queue has an index not greater than {@code walIdx}.
     * Only meaningful if segments were added in non-decreasing WAL index order.
     *
     * @param walIdx WAL segment index.
     * @return {@code true} if {@link #releaseWalKeys(long)} would release at least one segment.
     */
    boolean isReleaseWalKeysRequired(long walIdx) {
        Iterator<TrackedWalSegment> iter = trackedWalSegments.iterator();

        return iter.hasNext() && iter.next().idx <= walIdx;
    }

    /**
     * Stops tracking leading segments whose index is not greater than {@code walIdx}. Scanning stops at the
     * first segment with a greater index (segments are assumed to be ordered by WAL index).
     *
     * @param walIdx WAL segment index.
     * @return Cache group id -> identifiers of keys whose segments were released.
     */
    Map<Integer, Set<Integer>> releaseWalKeys(long walIdx) {
        Map<Integer, Set<Integer>> rmvKeys = new HashMap<>();
        Iterator<TrackedWalSegment> iter = trackedWalSegments.iterator();

        while (iter.hasNext()) {
            TrackedWalSegment segment = iter.next();

            if (segment.idx > walIdx) {
                break;
            }

            iter.remove();

            rmvKeys.computeIfAbsent(segment.grpId, v -> new HashSet<>()).add(segment.keyId);
        }

        return rmvKeys;
    }

    /** Association of a WAL segment index with the (group, key) pair it references. */
    protected static final class TrackedWalSegment implements Serializable {
        /** Serial version uid. */
        private static final long serialVersionUID = 0L;

        /** WAL segment index. */
        private final long idx;

        /** Cache group id. */
        private final int grpId;

        /** Key identifier. */
        private final int keyId;

        /**
         * @param idx WAL segment index.
         * @param grpId Cache group id.
         * @param keyId Key identifier.
         */
        public TrackedWalSegment(long idx, int grpId, int keyId) {
            this.idx = idx;
            this.grpId = grpId;
            this.keyId = keyId;
        }
    }
}

