package org.apache.hadoop.mapreduce.v2.app;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.hadoop.mapreduce.v2.api.records.TaskType;
import org.apache.hadoop.mapreduce.v2.app.MRAppMaster;
import org.apache.hadoop.mapreduce.v2.app.job.event.JobCounterUpdateEvent;
import org.apache.hadoop.mapreduce.v2.app.job.event.TaskAttemptEvent;
import org.apache.hadoop.mapreduce.v2.app.rm.preemption.AMPreemptionPolicy;
import org.apache.hadoop.mapreduce.v2.app.rm.preemption.KillAMPreemptionPolicy;
import org.apache.hadoop.mapreduce.v2.util.MRBuilderUtils;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.PreemptionContainer;
import org.apache.hadoop.yarn.api.records.PreemptionContract;
import org.apache.hadoop.yarn.api.records.PreemptionMessage;
import org.apache.hadoop.yarn.api.records.StrictPreemptionContract;
import org.apache.hadoop.yarn.event.Event;
import org.apache.hadoop.yarn.event.EventHandler;
import org.apache.hadoop.yarn.factories.RecordFactory;
import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/**
 * Tests for {@link KillAMPreemptionPolicy}.
 *
 * <p>Builds preemption messages with every combination of strict and negotiable contracts
 * naming a single container. It then checks how many {@link TaskAttemptEvent}s and
 * {@link JobCounterUpdateEvent}s the policy dispatches through the application context's
 * event handler.
 */
public class TestKillAMPreemptionPolicy {
    private final RecordFactory recordFactory = RecordFactoryProvider.getRecordFactory(null);

    @Test
    public void testKillAMPreemptPolicy() {
        ApplicationId appId = ApplicationId.newInstance(123456789L, 1);
        ContainerId container =
            ContainerId.newContainerId(ApplicationAttemptId.newInstance(appId, 1), 1L);

        // Policy context: every container maps to the same map task attempt, and the
        // single running container is the one named in the preemption messages.
        AMPreemptionPolicy.Context ctxt = Mockito.mock(AMPreemptionPolicy.Context.class);
        Mockito.when(ctxt.getTaskAttempt(ArgumentMatchers.any(ContainerId.class)))
            .thenReturn(MRBuilderUtils.newTaskAttemptId(
                MRBuilderUtils.newTaskId(MRBuilderUtils.newJobId(appId, 1), 1, TaskType.MAP), 0));
        List<Container> containers = new ArrayList<>();
        containers.add(Container.newInstance(container, null, null, null, null, null));
        Mockito.when(ctxt.getContainers(ArgumentMatchers.any(TaskType.class)))
            .thenReturn(containers);

        KillAMPreemptionPolicy policy = new KillAMPreemptionPolicy();

        // No contract in the message: nothing is expected to be dispatched.
        assertPreemption(policy, ctxt, false, false, container, 0);
        // Strict contract only.
        assertPreemption(policy, ctxt, true, false, container, 1);
        // Negotiable contract only.
        assertPreemption(policy, ctxt, false, true, container, 1);
        // Both contracts list the container; one event pair is expected per listing.
        assertPreemption(policy, ctxt, true, true, container, 2);
    }

    /**
     * Re-initializes {@code policy} with a fresh mocked application context, delivers one
     * preemption message, and verifies the dispatched events.
     *
     * @param policy the policy under test (re-initialized on every call)
     * @param ctxt the mocked preemption context passed to {@code preempt}
     * @param strict whether the message carries a strict contract naming {@code container}
     * @param contract whether the message carries a negotiable contract naming {@code container}
     * @param container the container named in the contract(s)
     * @param expectedEvents expected number of {@link TaskAttemptEvent}s and, separately,
     *     of {@link JobCounterUpdateEvent}s handled by the context's event handler
     */
    private void assertPreemption(KillAMPreemptionPolicy policy, AMPreemptionPolicy.Context ctxt,
            boolean strict, boolean contract, ContainerId container, int expectedEvents) {
        MRAppMaster.RunningAppContext appContext = getRunningAppContext();
        policy.init(appContext);
        policy.preempt(ctxt, getPreemptionMessage(strict, contract, container));
        Mockito.verify(appContext.getEventHandler(), Mockito.times(expectedEvents))
            .handle(ArgumentMatchers.any(TaskAttemptEvent.class));
        Mockito.verify(appContext.getEventHandler(), Mockito.times(expectedEvents))
            .handle(ArgumentMatchers.any(JobCounterUpdateEvent.class));
    }

    /**
     * Creates a mocked {@link MRAppMaster.RunningAppContext} whose {@code getEventHandler()}
     * always returns the same fresh mock {@link EventHandler}, so that calls can be verified.
     */
    private MRAppMaster.RunningAppContext getRunningAppContext() {
        MRAppMaster.RunningAppContext appContext = Mockito.mock(MRAppMaster.RunningAppContext.class);
        Mockito.when(appContext.getEventHandler()).thenReturn(Mockito.mock(EventHandler.class));
        return appContext;
    }

    /**
     * Builds a {@link PreemptionMessage} naming {@code container}.
     *
     * @param strict attach a {@link StrictPreemptionContract} containing the container
     * @param contract attach a negotiable {@link PreemptionContract} containing the container
     * @param container the container to list in each attached contract
     * @return the message; it has no contracts when both flags are {@code false}
     */
    private PreemptionMessage getPreemptionMessage(boolean strict, boolean contract,
            ContainerId container) {
        PreemptionMessage msg = recordFactory.newRecordInstance(PreemptionMessage.class);
        Set<PreemptionContainer> cntrs = new HashSet<>();
        PreemptionContainer preemptedContainer =
            recordFactory.newRecordInstance(PreemptionContainer.class);
        preemptedContainer.setId(container);
        cntrs.add(preemptedContainer);
        if (strict) {
            StrictPreemptionContract strictContract =
                recordFactory.newRecordInstance(StrictPreemptionContract.class);
            strictContract.setContainers(cntrs);
            msg.setStrictContract(strictContract);
        }
        if (contract) {
            PreemptionContract negotiableContract =
                recordFactory.newRecordInstance(PreemptionContract.class);
            negotiableContract.setContainers(cntrs);
            msg.setContract(negotiableContract);
        }
        return msg;
    }
}
