package dev.yavuztas.junit;

import java.lang.reflect.Method;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
import org.junit.platform.commons.util.AnnotationUtils;
import org.junit.platform.commons.util.ClassUtils;
import org.junit.platform.commons.util.Preconditions;
import org.junit.platform.commons.util.ReflectionUtils;

/* loaded from: input_file:dev/yavuztas/junit/ConcurrentExtension.class */
public class ConcurrentExtension implements InvocationInterceptor {
    private int globalThreadCount;

    public static ConcurrentExtension withGlobalThreadCount(int i) {
        ConcurrentExtension concurrentExtension = new ConcurrentExtension();
        concurrentExtension.globalThreadCount = i;
        return concurrentExtension;
    }

    public void interceptTestMethod(InvocationInterceptor.Invocation<Void> invocation, ReflectiveInvocationContext<Method> reflectiveInvocationContext, ExtensionContext extensionContext) throws Throwable {
        Method method = (Method) reflectiveInvocationContext.getExecutable();
        Optional findAnnotation = AnnotationUtils.findAnnotation(method, ConcurrentTest.class);
        if (!findAnnotation.isPresent()) {
            invocation.proceed();
            return;
        }
        ConcurrentTest concurrentTest = (ConcurrentTest) findAnnotation.get();
        Throwable[] thArr = new Throwable[1];
        int threadCount = threadCount(concurrentTest, method);
        ExecutorService newFixedThreadPool = Executors.newFixedThreadPool(threadCount);
        for (int i = 0; i < threadCount; i++) {
            CompletableFuture.runAsync(() -> {
                try {
                    if (concurrentTest.printInfo()) {
                        printInfo(method);
                    }
                    ReflectionUtils.invokeMethod(method, reflectiveInvocationContext.getTarget().orElse(null), reflectiveInvocationContext.getArguments().toArray());
                } catch (Throwable th) {
                    thArr[0] = th;
                }
            }, newFixedThreadPool);
        }
        awaitTerminationAfterShutdown(newFixedThreadPool, timeout(reflectiveInvocationContext.getTargetClass(), method));
        if (thArr[0] != null) {
            throw thArr[0];
        }
        invocation.skip();
    }

    private void awaitTerminationAfterShutdown(ExecutorService executorService, Timeout timeout) {
        long value;
        executorService.shutdown();
        if (timeout != null) {
            try {
                value = timeout.value();
            } catch (InterruptedException e) {
                executorService.shutdownNow();
                Thread.currentThread().interrupt();
                return;
            }
        } else {
            value = Long.MAX_VALUE;
        }
        if (!executorService.awaitTermination(value, timeout != null ? timeout.unit() : TimeUnit.NANOSECONDS)) {
            executorService.shutdownNow();
        }
    }

    private void printInfo(Method method) {
        System.out.println(String.format("Thread#%s - %s(%s)", Long.valueOf(Thread.currentThread().getId()), method.getName(), ClassUtils.nullSafeToString((v0) -> {
            return v0.getSimpleName();
        }, method.getParameterTypes())));
    }

    private int threadCount(ConcurrentTest concurrentTest, Method method) {
        int count = concurrentTest.count();
        Preconditions.condition(count > 0, () -> {
            return String.format("Configuration error: @ConcurrentTest on method [%s] must be declared with a positive 'count'.", method);
        });
        return (concurrentTest.overrideGlobal() || this.globalThreadCount <= 0) ? count : this.globalThreadCount;
    }

    private Timeout timeout(Class<?> cls, Method method) {
        Optional findAnnotation = AnnotationUtils.findAnnotation(method, Timeout.class);
        return !findAnnotation.isPresent() ? (Timeout) AnnotationUtils.findAnnotation(cls, Timeout.class).orElse(null) : (Timeout) findAnnotation.get();
    }
}
