/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.internal.processors.rest;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.ignite.configuration.ConnectorConfiguration;
import org.apache.ignite.configuration.IgniteConfiguration;
import org.apache.ignite.internal.client.marshaller.jdk.GridClientJdkMarshaller;
import org.apache.ignite.internal.processors.rest.client.message.GridClientHandshakeRequest;
import org.apache.ignite.internal.processors.rest.client.message.GridClientMessage;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.internal.util.lang.GridAbsPredicate;
import org.apache.ignite.internal.util.typedef.internal.U;
import org.apache.ignite.testframework.GridTestUtils;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
import org.junit.Test;

public class TcpRestUnmarshalVulnerabilityTest
extends GridCommonAbstractTest {
    private static final GridClientJdkMarshaller MARSH = new GridClientJdkMarshaller();
    private static final AtomicBoolean SHARED = new AtomicBoolean();
    private static int port;
    private static String host;

    protected IgniteConfiguration getConfiguration(String igniteInstanceName) throws Exception {
        IgniteConfiguration cfg = super.getConfiguration(igniteInstanceName);
        ConnectorConfiguration connCfg = new ConnectorConfiguration();
        port = connCfg.getPort();
        host = connCfg.getHost();
        cfg.setConnectorConfiguration(connCfg);
        return cfg;
    }

    protected void beforeTest() throws Exception {
        super.beforeTest();
        SHARED.set(false);
        System.clearProperty("IGNITE_MARSHALLER_WHITELIST");
        System.clearProperty("IGNITE_MARSHALLER_BLACKLIST");
        IgniteUtils.clearClassCache();
    }

    @Test
    public void testNoLists() throws Exception {
        this.testExploit(true);
    }

    @Test
    public void testWhiteListIncluded() throws Exception {
        String path = U.resolveIgnitePath((String)"modules/core/src/test/config/class_list_exploit_included.txt").getPath();
        System.setProperty("IGNITE_MARSHALLER_WHITELIST", path);
        this.testExploit(true);
    }

    @Test
    public void testWhiteListExcluded() throws Exception {
        String path = U.resolveIgnitePath((String)"modules/core/src/test/config/class_list_exploit_excluded.txt").getPath();
        System.setProperty("IGNITE_MARSHALLER_WHITELIST", path);
        this.testExploit(false);
    }

    @Test
    public void testBlackListIncluded() throws Exception {
        String path = U.resolveIgnitePath((String)"modules/core/src/test/config/class_list_exploit_included.txt").getPath();
        System.setProperty("IGNITE_MARSHALLER_BLACKLIST", path);
        this.testExploit(false);
    }

    @Test
    public void testBlackListExcluded() throws Exception {
        String path = U.resolveIgnitePath((String)"modules/core/src/test/config/class_list_exploit_excluded.txt").getPath();
        System.setProperty("IGNITE_MARSHALLER_BLACKLIST", path);
        this.testExploit(true);
    }

    @Test
    public void testBothListIncluded() throws Exception {
        String path = U.resolveIgnitePath((String)"modules/core/src/test/config/class_list_exploit_included.txt").getPath();
        System.setProperty("IGNITE_MARSHALLER_WHITELIST", path);
        System.setProperty("IGNITE_MARSHALLER_BLACKLIST", path);
        this.testExploit(false);
    }

    private void testExploit(boolean positive) throws Exception {
        try {
            this.startGrid();
            this.attack(TcpRestUnmarshalVulnerabilityTest.marshal(new Exploit()).array());
            boolean res = GridTestUtils.waitForCondition((GridAbsPredicate)new GridAbsPredicate(){

                public boolean apply() {
                    return SHARED.get();
                }
            }, (long)3000L);
            if (positive) {
                TcpRestUnmarshalVulnerabilityTest.assertTrue((boolean)res);
            } else {
                TcpRestUnmarshalVulnerabilityTest.assertFalse((boolean)res);
            }
        }
        finally {
            this.stopAllGrids();
        }
    }

    private static ByteBuffer marshal(Object obj) throws IOException {
        return MARSH.marshal(obj, 0);
    }

    private void attack(byte[] data) throws IOException {
        InetAddress addr = InetAddress.getByName(host);
        try (Socket sock = new Socket(addr, port);
             BufferedOutputStream os = new BufferedOutputStream(sock.getOutputStream());){
            ((OutputStream)os).write(-111);
            GridClientHandshakeRequest req = new GridClientHandshakeRequest();
            req.marshallerId((byte)2);
            ((OutputStream)os).write(req.rawBytes());
            ((OutputStream)os).flush();
            BufferedInputStream is = new BufferedInputStream(sock.getInputStream());
            ((InputStream)is).read(new byte[146]);
            int len = data.length + 40;
            ((OutputStream)os).write(-112);
            ((OutputStream)os).write((byte)(len >> 24));
            ((OutputStream)os).write((byte)(len >> 16));
            ((OutputStream)os).write((byte)(len >> 8));
            ((OutputStream)os).write((byte)len);
            ((OutputStream)os).write(new byte[40]);
            ((OutputStream)os).write(data);
            ((OutputStream)os).flush();
        }
    }

    private static class Exploit
    implements GridClientMessage {
        private Exploit() {
        }

        private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
            SHARED.set(true);
        }

        public long requestId() {
            return 0L;
        }

        public void requestId(long reqId) {
        }

        public UUID clientId() {
            return null;
        }

        public void clientId(UUID id) {
        }

        public UUID destinationId() {
            return null;
        }

        public void destinationId(UUID id) {
        }

        public byte[] sessionToken() {
            return new byte[0];
        }

        public void sessionToken(byte[] sesTok) {
        }
    }
}

