package de.rub.nds.tlsattacker.core.protocol.handler;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.CompressionMethod;
import de.rub.nds.tlsattacker.core.constants.DigestAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ExtensionType;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.exceptions.AdjustmentException;
import de.rub.nds.tlsattacker.core.exceptions.CryptoException;
import de.rub.nds.tlsattacker.core.exceptions.WorkflowExecutionException;
import de.rub.nds.tlsattacker.core.protocol.message.ClientHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.parser.ClientHelloParser;
import de.rub.nds.tlsattacker.core.protocol.preparator.ClientHelloPreparator;
import de.rub.nds.tlsattacker.core.protocol.serializer.ClientHelloSerializer;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipherFactory;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySetGenerator;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Handler for {@link ClientHelloMessage}s.
 *
 * <p>Supplies the parser, preparator and serializer for ClientHello messages. In
 * {@link #adjustTLSContext(ClientHelloMessage)} it copies the values carried by the message
 * (protocol version, session id, cipher suites, compression methods, DTLS cookie, extensions,
 * client random and the complete serialized message) into the {@link TlsContext}. When TLS 1.3
 * early data applies, it additionally derives the client early traffic secret and installs a
 * record cipher keyed with the early-traffic key set.
 */
public class ClientHelloHandler extends HandshakeMessageHandler<ClientHelloMessage> {
    private static final Logger LOGGER = LogManager.getLogger();

    /**
     * Creates a handler operating on the given context.
     *
     * @param tlsContext the context that parsed values are written into
     */
    public ClientHelloHandler(TlsContext tlsContext) {
        super(tlsContext);
    }

    /**
     * Creates a parser for a ClientHello, using the last record version and the current config.
     *
     * @param message the raw bytes to parse
     * @param pointer the start position within {@code message}
     * @return a new {@link ClientHelloParser}
     */
    @Override
    public ClientHelloParser getParser(byte[] message, int pointer) {
        return new ClientHelloParser(pointer, message, tlsContext.getChooser().getLastRecordVersion(),
                tlsContext.getConfig());
    }

    /**
     * Creates a preparator that fills {@code message} using the context's chooser.
     *
     * @param message the message to prepare
     * @return a new {@link ClientHelloPreparator}
     */
    @Override
    public ClientHelloPreparator getPreparator(ClientHelloMessage message) {
        return new ClientHelloPreparator(tlsContext.getChooser(), message);
    }

    /**
     * Creates a serializer for {@code message} using the currently selected protocol version.
     *
     * @param message the message to serialize
     * @return a new {@link ClientHelloSerializer}
     */
    @Override
    public ClientHelloSerializer getSerializer(ClientHelloMessage message) {
        return new ClientHelloSerializer(message, tlsContext.getChooser().getSelectedProtocolVersion());
    }

    /**
     * Copies the fields of a ClientHello into the context.
     *
     * <p>If TLS 1.3 is selected and the EARLY_DATA extension is negotiated, the early traffic
     * secret is derived and the early-traffic record cipher is activated.
     *
     * @param message the ClientHello whose values are applied
     * @throws AdjustmentException if deriving the early traffic secret or cipher fails
     */
    @Override
    public void adjustTLSContext(ClientHelloMessage message) {
        adjustProtocolVersion(message);
        adjustSessionID(message);
        adjustClientSupportedCipherSuites(message);
        adjustClientSupportedCompressions(message);
        if (isCookieFieldSet(message)) {
            adjustDTLSCookie(message);
        }
        adjustExtensions(message);
        adjustRandomContext(message);
        if (tlsContext.getChooser().getSelectedProtocolVersion().isTLS13()
                && tlsContext.isExtensionNegotiated(ExtensionType.EARLY_DATA)) {
            try {
                adjustEarlyTrafficSecret();
                setClientRecordCipherEarly();
            } catch (CryptoException e) {
                throw new AdjustmentException("Could not adjust", e);
            }
        }
        tlsContext.setLastClientHello(message.getCompleteResultingMessage().getValue());
    }

    /** Returns whether the message carries a (DTLS) cookie field at all. */
    private boolean isCookieFieldSet(ClientHelloMessage message) {
        return message.getCookie() != null;
    }

    /**
     * Stores the client's offered cipher suites in the context. The stored list is {@code null}
     * when the cipher-suite field has an odd length (see {@link #convertCipherSuites(byte[])}).
     */
    private void adjustClientSupportedCipherSuites(ClientHelloMessage message) {
        List<CipherSuite> suites = convertCipherSuites(message.getCipherSuites().getValue());
        tlsContext.setClientSupportedCipherSuites(suites);
        // log4j renders a null argument as "null", so one call covers both cases.
        LOGGER.debug("Set ClientSupportedCipherSuites in Context to {}", suites);
    }

    /** Stores the client's offered (and recognised) compression methods in the context. */
    private void adjustClientSupportedCompressions(ClientHelloMessage message) {
        List<CompressionMethod> compressions = convertCompressionMethods(message.getCompressions().getValue());
        tlsContext.setClientSupportedCompressions(compressions);
        LOGGER.debug("Set ClientSupportedCompressions in Context to {}", compressions);
    }

    /** Stores the DTLS cookie value of the message in the context. */
    private void adjustDTLSCookie(ClientHelloMessage message) {
        byte[] cookie = message.getCookie().getValue();
        tlsContext.setDtlsCookie(cookie);
        LOGGER.debug("Set DTLS Cookie in Context to {}", ArrayConverter.bytesToHexString(cookie));
    }

    /** Stores the client session id of the message in the context. */
    private void adjustSessionID(ClientHelloMessage message) {
        byte[] sessionId = message.getSessionId().getValue();
        tlsContext.setClientSessionId(sessionId);
        LOGGER.debug("Set SessionId in Context to {}", ArrayConverter.bytesToHexString(sessionId, false));
    }

    /**
     * Stores the highest protocol version offered by the client. Unknown version bytes are logged
     * and leave the context unchanged.
     */
    private void adjustProtocolVersion(ClientHelloMessage message) {
        byte[] versionBytes = message.getProtocolVersion().getValue();
        ProtocolVersion version = ProtocolVersion.getProtocolVersion(versionBytes);
        if (version == null) {
            LOGGER.warn("Did not Adjust ProtocolVersion since version is undefined {}",
                    ArrayConverter.bytesToHexString(versionBytes));
            return;
        }
        tlsContext.setHighestClientProtocolVersion(version);
        LOGGER.debug("Set HighestClientProtocolVersion in Context to {}", version.name());
    }

    /** Stores the client random of the message in the context. */
    private void adjustRandomContext(ClientHelloMessage message) {
        tlsContext.setClientRandom(message.getRandom().getValue());
        LOGGER.debug("Set ClientRandom in Context to {}", ArrayConverter.bytesToHexString(tlsContext.getClientRandom()));
    }

    /**
     * Converts each byte into a {@link CompressionMethod}; unknown values are logged and skipped.
     *
     * @param bytes one compression-method identifier per byte
     * @return the recognised compression methods, in input order (never {@code null})
     */
    private List<CompressionMethod> convertCompressionMethods(byte[] bytes) {
        List<CompressionMethod> methods = new ArrayList<>(bytes.length);
        for (byte value : bytes) {
            CompressionMethod method = CompressionMethod.getCompressionMethod(value);
            if (method == null) {
                LOGGER.warn("Could not convert {} into a CompressionMethod", value);
            } else {
                methods.add(method);
            }
        }
        return methods;
    }

    /**
     * Converts a byte array into cipher suites, two bytes per suite; unknown suites are logged and
     * skipped.
     *
     * @param bytes the concatenated two-byte cipher-suite identifiers
     * @return the recognised cipher suites in input order, or {@code null} if {@code bytes} has an
     *     odd length
     */
    private List<CipherSuite> convertCipherSuites(byte[] bytes) {
        if (bytes.length % 2 != 0) {
            LOGGER.warn("Cannot convert:{} to a List<CipherSuite>", ArrayConverter.bytesToHexString(bytes, false));
            return null;
        }
        List<CipherSuite> suites = new ArrayList<>(bytes.length / 2);
        for (int offset = 0; offset < bytes.length; offset += 2) {
            byte[] suiteBytes = {bytes[offset], bytes[offset + 1]};
            CipherSuite suite = CipherSuite.getCipherSuite(suiteBytes);
            if (suite == null) {
                LOGGER.warn("Cannot convert:{} to a CipherSuite", ArrayConverter.bytesToHexString(suiteBytes));
            } else {
                suites.add(suite);
            }
        }
        return suites;
    }

    /**
     * After serializing a ClientHello as the client, activates the early-traffic keys if the
     * EARLY_DATA extension was proposed. Failures are logged and not rethrown.
     */
    @Override
    public void adjustTlsContextAfterSerialize(ClientHelloMessage message) {
        if (tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.CLIENT
                && tlsContext.isExtensionProposed(ExtensionType.EARLY_DATA)) {
            try {
                adjustEarlyTrafficSecret();
                setClientRecordCipherEarly();
            } catch (CryptoException e) {
                // Deliberately best-effort: serialization already succeeded.
                LOGGER.warn("Encountered an exception in adjust after Serialize", e);
            }
        }
    }

    /**
     * Computes the early secret as HKDF-Extract(empty salt, early-data PSK), then derives the
     * client early traffic secret from it and the current handshake digest. Both are stored in
     * the context.
     *
     * @throws CryptoException if an HKDF operation fails
     */
    private void adjustEarlyTrafficSecret() throws CryptoException {
        CipherSuite earlyDataSuite = tlsContext.getChooser().getEarlyDataCipherSuite();
        HKDFAlgorithm hkdfAlgorithm = AlgorithmResolver.getHKDFAlgorithm(earlyDataSuite);
        DigestAlgorithm digestAlgorithm = AlgorithmResolver.getDigestAlgorithm(ProtocolVersion.TLS13, earlyDataSuite);
        tlsContext.setEarlySecret(
                HKDFunction.extract(hkdfAlgorithm, new byte[0], tlsContext.getChooser().getEarlyDataPsk()));
        byte[] earlyTrafficSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestAlgorithm.getJavaName(),
                tlsContext.getChooser().getEarlySecret(), HKDFunction.CLIENT_EARLY_TRAFFIC_SECRET,
                tlsContext.getDigest().getRawBytes());
        tlsContext.setClientEarlyTrafficSecret(earlyTrafficSecret);
        LOGGER.debug("EarlyTrafficSecret: {}", ArrayConverter.bytesToHexString(earlyTrafficSecret));
    }

    /**
     * Switches the client key set to EARLY_TRAFFIC_SECRETS and installs a matching record cipher.
     * A server activates it for decryption and resets its read sequence number; a client
     * activates it for encryption and resets its write sequence number.
     *
     * @throws CryptoException if building the record cipher fails
     * @throws WorkflowExecutionException if the key set cannot be generated
     */
    private void setClientRecordCipherEarly() throws CryptoException {
        try {
            tlsContext.setActiveClientKeySetType(Tls13KeySetType.EARLY_TRAFFIC_SECRETS);
            LOGGER.debug("Setting cipher for client to use early secrets");
            tlsContext.getRecordLayer().setRecordCipher(RecordCipherFactory.getRecordCipher(tlsContext,
                    KeySetGenerator.generateKeySet(tlsContext, ProtocolVersion.TLS13,
                            tlsContext.getActiveClientKeySetType()),
                    tlsContext.getChooser().getEarlyDataCipherSuite()));
            if (tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.SERVER) {
                tlsContext.setReadSequenceNumber(0L);
                tlsContext.getRecordLayer().updateDecryptionCipher();
            } else {
                tlsContext.setWriteSequenceNumber(0L);
                tlsContext.getRecordLayer().updateEncryptionCipher();
            }
        } catch (NoSuchAlgorithmException e) {
            LOGGER.error("Unable to generate KeySet - unknown algorithm", e);
            // Keep the original exception as the cause so the failing algorithm is diagnosable.
            throw new WorkflowExecutionException(e.toString(), e);
        }
    }
}
