package de.rub.nds.tlsattacker.core.protocol.handler;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.core.constants.AlgorithmResolver;
import de.rub.nds.tlsattacker.core.constants.CipherSuite;
import de.rub.nds.tlsattacker.core.constants.CompressionMethod;
import de.rub.nds.tlsattacker.core.constants.DigestAlgorithm;
import de.rub.nds.tlsattacker.core.constants.ExtensionType;
import de.rub.nds.tlsattacker.core.constants.HKDFAlgorithm;
import de.rub.nds.tlsattacker.core.constants.HandshakeMessageType;
import de.rub.nds.tlsattacker.core.constants.ProtocolVersion;
import de.rub.nds.tlsattacker.core.constants.Tls13KeySetType;
import de.rub.nds.tlsattacker.core.crypto.HKDFunction;
import de.rub.nds.tlsattacker.core.crypto.KeyShareCalculator;
import de.rub.nds.tlsattacker.core.crypto.ec.CurveFactory;
import de.rub.nds.tlsattacker.core.crypto.ec.EllipticCurve;
import de.rub.nds.tlsattacker.core.crypto.ec.Point;
import de.rub.nds.tlsattacker.core.crypto.ec.PointFormatter;
import de.rub.nds.tlsattacker.core.exceptions.AdjustmentException;
import de.rub.nds.tlsattacker.core.exceptions.CryptoException;
import de.rub.nds.tlsattacker.core.protocol.message.ServerHelloMessage;
import de.rub.nds.tlsattacker.core.protocol.message.computations.PWDComputations;
import de.rub.nds.tlsattacker.core.protocol.message.extension.keyshare.DragonFlyKeyShareEntry;
import de.rub.nds.tlsattacker.core.protocol.message.extension.keyshare.KeyShareStoreEntry;
import de.rub.nds.tlsattacker.core.protocol.parser.ServerHelloParser;
import de.rub.nds.tlsattacker.core.protocol.parser.extension.keyshare.DragonFlyKeyShareEntryParser;
import de.rub.nds.tlsattacker.core.protocol.preparator.ServerHelloPreparator;
import de.rub.nds.tlsattacker.core.protocol.serializer.ServerHelloSerializer;
import de.rub.nds.tlsattacker.core.record.cipher.RecordCipherFactory;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySet;
import de.rub.nds.tlsattacker.core.record.cipher.cryptohelper.KeySetGenerator;
import de.rub.nds.tlsattacker.core.state.TlsContext;
import de.rub.nds.tlsattacker.core.workflow.chooser.Chooser;
import de.rub.nds.tlsattacker.transport.ConnectionEndType;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.List;
import javax.crypto.Mac;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/* loaded from: input_file:de/rub/nds/tlsattacker/core/protocol/handler/ServerHelloHandler.class */
/**
 * Handler for ServerHello messages (including TLS 1.3 HelloRetryRequests).
 *
 * <p>Supplies the preparator, serializer and parser for {@link ServerHelloMessage} and copies the
 * values chosen by the server (protocol version, compression method, session id, cipher suite,
 * server random, extensions) into the {@link TlsContext}. For TLS 1.3 it additionally derives the
 * handshake traffic secrets and installs the handshake-secret record cipher for the server side.
 */
public class ServerHelloHandler extends HandshakeMessageHandler<ServerHelloMessage> {
    private static final Logger LOGGER = LogManager.getLogger();

    public ServerHelloHandler(TlsContext tlsContext) {
        super(tlsContext);
    }

    @Override
    public ServerHelloPreparator getPreparator(ServerHelloMessage serverHelloMessage) {
        return new ServerHelloPreparator(this.tlsContext.getChooser(), serverHelloMessage);
    }

    @Override
    public ServerHelloSerializer getSerializer(ServerHelloMessage serverHelloMessage) {
        return new ServerHelloSerializer(serverHelloMessage, this.tlsContext.getChooser().getSelectedProtocolVersion());
    }

    /**
     * Creates a parser for a ServerHello.
     *
     * @param message the raw bytes containing the message
     * @param pointer the offset inside {@code message} at which parsing starts
     * @return a parser configured with the last record version and the current config
     */
    @Override
    public ServerHelloParser getParser(byte[] message, int pointer) {
        return new ServerHelloParser(pointer, message, this.tlsContext.getChooser().getLastRecordVersion(), this.tlsContext.getConfig());
    }

    /**
     * Applies a received or sent ServerHello to the context.
     *
     * <p>For a TLS 1.3 HelloRetryRequest only the transcript digest is rewritten and processing
     * stops. For a regular TLS 1.3 ServerHello the handshake secrets are derived. When the peer
     * sent the message, the server record cipher is switched to the handshake keys immediately.
     * Finally the PRF is selected and, if a session with the server's session id is cached, its
     * master secret is loaded.
     *
     * @param serverHelloMessage the ServerHello whose values are applied
     */
    @Override
    public void adjustTLSContext(ServerHelloMessage serverHelloMessage) {
        adjustSelectedProtocolVersion(serverHelloMessage);
        adjustSelectedCompression(serverHelloMessage);
        adjustSelectedSessionID(serverHelloMessage);
        adjustSelectedCipherSuite(serverHelloMessage);
        adjustServerRandom(serverHelloMessage);
        adjustExtensions(serverHelloMessage);
        warnOnConflictingExtensions();
        if (serverHelloMessage.isTls13HelloRetryRequest().booleanValue()) {
            adjustHelloRetryDigest(serverHelloMessage);
            return;
        }
        Chooser chooser = this.tlsContext.getChooser();
        if (chooser.getSelectedProtocolVersion().isTLS13()) {
            adjustHandshakeTrafficSecrets(adjustKeyShareStoreEntry());
            // When we are the sender, the cipher is switched only after serialization (see
            // adjustTlsContextAfterSerialize) so the ServerHello itself is not encrypted.
            if (this.tlsContext.getTalkingConnectionEndType() != chooser.getConnectionEndType()) {
                setServerRecordCipher();
            }
        }
        adjustPRF(serverHelloMessage);
        byte[] serverSessionId = chooser.getServerSessionId();
        if (this.tlsContext.hasSession(serverSessionId)) {
            LOGGER.info("Resuming Session");
            LOGGER.debug("Loading MasterSecret");
            this.tlsContext.setMasterSecret(this.tlsContext.getIdSession(serverSessionId).getMasterSecret());
        }
    }

    /** Stores the server-selected cipher suite in the context if it is a known suite. */
    private void adjustSelectedCipherSuite(ServerHelloMessage serverHelloMessage) {
        CipherSuite cipherSuite = null;
        if (serverHelloMessage.getSelectedCipherSuite() != null) {
            cipherSuite = CipherSuite.getCipherSuite((byte[]) serverHelloMessage.getSelectedCipherSuite().getValue());
        }
        if (cipherSuite == null) {
            LOGGER.warn("Unknown CipherSuite, did not adjust Context");
        } else {
            this.tlsContext.setSelectedCipherSuite(cipherSuite);
            LOGGER.debug("Set SelectedCipherSuite in Context to " + cipherSuite.name());
        }
    }

    /** Copies the ServerHello random into the context. */
    private void adjustServerRandom(ServerHelloMessage serverHelloMessage) {
        this.tlsContext.setServerRandom((byte[]) serverHelloMessage.getRandom().getValue());
        LOGGER.debug("Set ServerRandom in Context to " + ArrayConverter.bytesToHexString(this.tlsContext.getServerRandom()));
    }

    /** Stores the server-selected compression method in the context if it is a known method. */
    private void adjustSelectedCompression(ServerHelloMessage serverHelloMessage) {
        CompressionMethod compressionMethod = null;
        if (serverHelloMessage.getSelectedCompressionMethod() != null) {
            compressionMethod = CompressionMethod.getCompressionMethod(((Byte) serverHelloMessage.getSelectedCompressionMethod().getValue()).byteValue());
        }
        if (compressionMethod == null) {
            LOGGER.warn("Not adjusting CompressionMethod - Method is null!");
        } else {
            this.tlsContext.setSelectedCompressionMethod(compressionMethod);
            LOGGER.debug("Set SelectedCompressionMethod in Context to " + compressionMethod.name());
        }
    }

    /** Copies the ServerHello session id into the context as the server session id. */
    private void adjustSelectedSessionID(ServerHelloMessage serverHelloMessage) {
        byte[] sessionId = (byte[]) serverHelloMessage.getSessionId().getValue();
        this.tlsContext.setServerSessionId(sessionId);
        LOGGER.debug("Set SessionID in Context to " + ArrayConverter.bytesToHexString(sessionId, false));
    }

    /**
     * Stores the ServerHello protocol version in the context if it is a known version.
     *
     * <p>A missing protocol version field is logged instead of dereferenced; previously this
     * path threw a NullPointerException while building the warning.
     */
    private void adjustSelectedProtocolVersion(ServerHelloMessage serverHelloMessage) {
        if (serverHelloMessage.getProtocolVersion() == null) {
            LOGGER.warn("Did not Adjust ProtocolVersion since version is undefined");
            return;
        }
        byte[] versionBytes = (byte[]) serverHelloMessage.getProtocolVersion().getValue();
        ProtocolVersion protocolVersion = ProtocolVersion.getProtocolVersion(versionBytes);
        if (protocolVersion == null) {
            LOGGER.warn("Did not Adjust ProtocolVersion since version is undefined " + ArrayConverter.bytesToHexString(versionBytes));
        } else {
            this.tlsContext.setSelectedProtocolVersion(protocolVersion);
            LOGGER.debug("Set SelectedProtocolVersion in Context to " + protocolVersion.name());
        }
    }

    /** Selects the PRF for the negotiated version and cipher suite; SSL versions are skipped. */
    private void adjustPRF(ServerHelloMessage serverHelloMessage) {
        Chooser chooser = this.tlsContext.getChooser();
        if (chooser.getSelectedProtocolVersion().isSSL()) {
            return;
        }
        this.tlsContext.setPrfAlgorithm(AlgorithmResolver.getPRFAlgorithm(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite()));
    }

    /**
     * Switches the server's record protection to the handshake traffic secrets. A client
     * installs them as its decryption cipher; a server installs them as its encryption cipher.
     */
    private void setServerRecordCipher() {
        this.tlsContext.setActiveServerKeySetType(Tls13KeySetType.HANDSHAKE_TRAFFIC_SECRETS);
        LOGGER.debug("Setting cipher for server to use handshake secrets");
        KeySet serverKeySet = getTls13KeySet(this.tlsContext, this.tlsContext.getActiveServerKeySetType());
        if (this.tlsContext.getChooser().getConnectionEndType() == ConnectionEndType.CLIENT) {
            this.tlsContext.getRecordLayer().updateDecryptionCipher(RecordCipherFactory.getRecordCipher(this.tlsContext, serverKeySet));
        } else {
            this.tlsContext.getRecordLayer().updateEncryptionCipher(RecordCipherFactory.getRecordCipher(this.tlsContext, serverKeySet));
        }
    }

    /**
     * Generates the key set of the given type for the selected protocol version.
     *
     * @throws UnsupportedOperationException if key derivation fails or an algorithm is missing
     */
    private KeySet getTls13KeySet(TlsContext tlsContext, Tls13KeySetType tls13KeySetType) {
        try {
            LOGGER.debug("Generating new KeySet");
            return KeySetGenerator.generateKeySet(tlsContext, tlsContext.getChooser().getSelectedProtocolVersion(), tls13KeySetType);
        } catch (CryptoException | NoSuchAlgorithmException e) {
            throw new UnsupportedOperationException("The specified Algorithm is not supported", e);
        }
    }

    /**
     * After we serialized a TLS 1.3 ServerHello (not a HelloRetryRequest), switch our
     * outgoing record protection to the handshake traffic secrets.
     */
    @Override
    public void adjustTlsContextAfterSerialize(ServerHelloMessage serverHelloMessage) {
        if (!this.tlsContext.getChooser().getSelectedProtocolVersion().isTLS13() || serverHelloMessage.isTls13HelloRetryRequest().booleanValue()) {
            return;
        }
        setServerRecordCipher();
    }

    /**
     * Derives the TLS 1.3 handshake secret and the client/server handshake traffic secrets and
     * stores them in the context.
     *
     * @param keyShareStoreEntry the peer's key share used to compute the (EC)DHE/PWD secret
     * @throws AdjustmentException if a cryptographic operation fails
     */
    private void adjustHandshakeTrafficSecrets(KeyShareStoreEntry keyShareStoreEntry) {
        Chooser chooser = this.tlsContext.getChooser();
        HKDFAlgorithm hkdfAlgorithm = AlgorithmResolver.getHKDFAlgorithm(chooser.getSelectedCipherSuite());
        DigestAlgorithm digestAlgorithm = AlgorithmResolver.getDigestAlgorithm(chooser.getSelectedProtocolVersion(), chooser.getSelectedCipherSuite());
        String digestName = digestAlgorithm.getJavaName();
        try {
            byte[] earlySecret = HKDFunction.extract(hkdfAlgorithm, new byte[0], getEarlySecretInputKey(hkdfAlgorithm));
            byte[] handshakeSalt = HKDFunction.deriveSecret(hkdfAlgorithm, digestName, earlySecret, HKDFunction.DERIVED, new byte[0]);
            byte[] sharedSecret = computeHandshakeSharedSecret(keyShareStoreEntry);

            byte[] handshakeSecret = HKDFunction.extract(hkdfAlgorithm, handshakeSalt, sharedSecret);
            this.tlsContext.setHandshakeSecret(handshakeSecret);
            LOGGER.debug("Set handshakeSecret in Context to " + ArrayConverter.bytesToHexString(handshakeSecret));

            byte[] clientTrafficSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestName, handshakeSecret, HKDFunction.CLIENT_HANDSHAKE_TRAFFIC_SECRET, this.tlsContext.getDigest().getRawBytes());
            this.tlsContext.setClientHandshakeTrafficSecret(clientTrafficSecret);
            LOGGER.debug("Set clientHandshakeTrafficSecret in Context to " + ArrayConverter.bytesToHexString(clientTrafficSecret));

            byte[] serverTrafficSecret = HKDFunction.deriveSecret(hkdfAlgorithm, digestName, handshakeSecret, HKDFunction.SERVER_HANDSHAKE_TRAFFIC_SECRET, this.tlsContext.getDigest().getRawBytes());
            this.tlsContext.setServerHandshakeTrafficSecret(serverTrafficSecret);
            LOGGER.debug("Set serverHandshakeTrafficSecret in Context to " + ArrayConverter.bytesToHexString(serverTrafficSecret));
        } catch (CryptoException | NoSuchAlgorithmException e) {
            throw new AdjustmentException(e);
        }
    }

    /**
     * Returns the input keying material for the early secret. This is the PSK when PSK usage is
     * configured or a PSK is present in the context. Otherwise it is an all-zero array whose
     * length equals the HKDF MAC output length.
     */
    private byte[] getEarlySecretInputKey(HKDFAlgorithm hkdfAlgorithm) throws NoSuchAlgorithmException {
        boolean usePsk = this.tlsContext.getConfig().isUsePsk().booleanValue() || this.tlsContext.getPsk() != null;
        if (usePsk) {
            return this.tlsContext.getChooser().getPsk();
        }
        int macLength = Mac.getInstance(hkdfAlgorithm.getMacAlgorithm().getJavaName()).getMacLength();
        return new byte[macLength];
    }

    /**
     * Computes the shared secret fed into the handshake secret.
     *
     * <p>For PWD suites the secret comes from {@link #computeSharedPWDSecret}. Otherwise it is
     * computed from the configured key-share private key and the peer's key share. A non-empty
     * configured default pre-master secret overrides the computed value.
     */
    private byte[] computeHandshakeSharedSecret(KeyShareStoreEntry keyShareStoreEntry) throws CryptoException, NoSuchAlgorithmException {
        if (this.tlsContext.getChooser().getSelectedCipherSuite().isPWD()) {
            return computeSharedPWDSecret(keyShareStoreEntry);
        }
        BigInteger keySharePrivate = this.tlsContext.getConfig().getKeySharePrivate();
        byte[] sharedSecret = KeyShareCalculator.computeSharedSecret(keyShareStoreEntry.getGroup(), keySharePrivate, keyShareStoreEntry.getPublicKey());
        byte[] defaultPreMasterSecret = this.tlsContext.getConfig().getDefaultPreMasterSecret();
        if (defaultPreMasterSecret.length > 0) {
            LOGGER.debug("Using specified PMS instead of computed PMS");
            return defaultPreMasterSecret;
        }
        return sharedSecret;
    }

    /**
     * Computes the PWD (dragonfly) shared secret: the x coordinate of
     * {@code priv * (scalar * PE + element)}, encoded as a fixed-length field element.
     *
     * <p>The field-element length is rounded up to whole bytes. Truncating division would
     * under-size curves whose modulus bit length is not a multiple of 8 (e.g. 521 bits).
     */
    private byte[] computeSharedPWDSecret(KeyShareStoreEntry keyShareStoreEntry) throws CryptoException {
        Chooser chooser = this.tlsContext.getChooser();
        EllipticCurve curve = CurveFactory.getCurve(keyShareStoreEntry.getGroup());
        DragonFlyKeyShareEntry dragonFlyEntry = new DragonFlyKeyShareEntryParser(keyShareStoreEntry.getPublicKey(), keyShareStoreEntry.getGroup()).parse();
        int fieldElementLength = (curve.getModulus().bitLength() + 7) / 8;
        Point peerElement = PointFormatter.fromRawFormat(keyShareStoreEntry.getGroup(), dragonFlyEntry.getRawPublicKey());
        BigInteger peerScalar = dragonFlyEntry.getScalar();
        Point passwordElement = PWDComputations.computePasswordElement(chooser, curve);
        byte[] ownPrivateBytes = chooser.getConnectionEndType() == ConnectionEndType.CLIENT
                ? chooser.getConfig().getDefaultClientPWDPrivate()
                : chooser.getConfig().getDefaultServerPWDPrivate();
        BigInteger ownPrivate = new BigInteger(1, ownPrivateBytes).mod(curve.getBasePointOrder());
        LOGGER.debug("Element: " + ArrayConverter.bytesToHexString(PointFormatter.toRawFormat(peerElement)));
        LOGGER.debug("Scalar: " + ArrayConverter.bytesToHexString(ArrayConverter.bigIntegerToByteArray(peerScalar)));
        Point sharedPoint = curve.mult(ownPrivate, curve.add(curve.mult(peerScalar, passwordElement), peerElement));
        return ArrayConverter.bigIntegerToByteArray(sharedPoint.getFieldX().getData(), fieldElementLength, true);
    }

    /**
     * Replaces the transcript digest after a HelloRetryRequest with
     * {@code message_hash || uint24(Hash.length) || Hash(ClientHello1) || HelloRetryRequest}
     * (RFC 8446, section 4.4.1). A missing digest algorithm is logged and the digest is left
     * unchanged.
     */
    private void adjustHelloRetryDigest(ServerHelloMessage serverHelloMessage) {
        try {
            byte[] lastClientHello = this.tlsContext.getChooser().getLastClientHello();
            LOGGER.debug("Replacing current digest for Hello Retry Request using Client Hello: " + ArrayConverter.bytesToHexString(lastClientHello));
            DigestAlgorithm digestAlgorithm = AlgorithmResolver.getDigestAlgorithm(ProtocolVersion.TLS13, this.tlsContext.getChooser().getSelectedCipherSuite());
            MessageDigest messageDigest = MessageDigest.getInstance(digestAlgorithm.getJavaName());
            messageDigest.update(lastClientHello);
            byte[] clientHelloHash = messageDigest.digest();
            byte[] helloRetryRequest = (byte[]) serverHelloMessage.getCompleteResultingMessage().getValue();
            this.tlsContext.getDigest().setRawBytes(HandshakeMessageType.MESSAGE_HASH.getArrayValue());
            this.tlsContext.getDigest().append(ArrayConverter.intToBytes(clientHelloHash.length, 3));
            this.tlsContext.getDigest().append(clientHelloHash);
            this.tlsContext.getDigest().append(helloRetryRequest);
            LOGGER.debug("Complete resulting digest: " + ArrayConverter.bytesToHexString(this.tlsContext.getDigest().getRawBytes()));
        } catch (NoSuchAlgorithmException e) {
            LOGGER.error("Could not replace digest for Hello Retry Request", e);
        }
    }

    /**
     * Warns when the peer's pre-TLS-1.3 ServerHello negotiated both max_fragment_length and
     * record_size_limit.
     */
    private void warnOnConflictingExtensions() {
        if (this.tlsContext.getTalkingConnectionEndType() == this.tlsContext.getChooser().getMyConnectionPeer()
                && !this.tlsContext.getChooser().getSelectedProtocolVersion().isTLS13()
                && this.tlsContext.isExtensionNegotiated(ExtensionType.MAX_FRAGMENT_LENGTH)
                && this.tlsContext.isExtensionNegotiated(ExtensionType.RECORD_SIZE_LIMIT)) {
            LOGGER.warn("Server sent max_fragment_length AND record_size_limit extensions");
        }
    }

    /**
     * Picks the peer's key share for the handshake and records its group and public key.
     *
     * <p>A client uses the server's key share. A server uses the client key share matching the
     * server key share's group, falling back to the client's first key share. The public key is
     * decoded as an EC point for curve groups. Otherwise it is decoded as an unsigned big-endian
     * integer, so a leading byte with its high bit set is not read as a negative number.
     *
     * @return the selected key share entry
     */
    private KeyShareStoreEntry adjustKeyShareStoreEntry() {
        Chooser chooser = this.tlsContext.getChooser();
        KeyShareStoreEntry selectedEntry;
        if (chooser.getConnectionEndType() == ConnectionEndType.CLIENT) {
            selectedEntry = chooser.getServerKeyShare();
        } else {
            selectedEntry = findClientKeyShareForServerGroup(chooser);
        }
        this.tlsContext.setSelectedGroup(selectedEntry.getGroup());
        if (selectedEntry.getGroup().isCurve()) {
            Point publicPoint;
            if (chooser.getSelectedCipherSuite().isPWD()) {
                publicPoint = PointFormatter.fromRawFormat(selectedEntry.getGroup(), selectedEntry.getPublicKey());
            } else {
                publicPoint = PointFormatter.formatFromByteArray(selectedEntry.getGroup(), selectedEntry.getPublicKey());
            }
            this.tlsContext.setServerEcPublicKey(publicPoint);
        } else {
            this.tlsContext.setServerDhPublicKey(new BigInteger(1, selectedEntry.getPublicKey()));
        }
        return selectedEntry;
    }

    /**
     * Returns the first client key share whose group equals the server key share's group, or
     * the client's first key share (with a warning) if none matches.
     */
    private KeyShareStoreEntry findClientKeyShareForServerGroup(Chooser chooser) {
        List<KeyShareStoreEntry> clientKeyShares = chooser.getClientKeyShares();
        byte[] serverGroup = chooser.getServerKeyShare().getGroup().getValue();
        for (KeyShareStoreEntry candidate : clientKeyShares) {
            if (Arrays.equals(candidate.getGroup().getValue(), serverGroup)) {
                return candidate;
            }
        }
        LOGGER.warn("Client did not send the KeyShareType we expected. Choosing first in his List");
        return clientKeyShares.get(0);
    }
}
