package org.nd4j.common.tests;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.Test;
import org.reflections.Reflections;
import org.reflections.scanners.MethodAnnotationsScanner;
import org.reflections.scanners.Scanner;
import org.reflections.util.ClasspathHelper;
import org.reflections.util.ConfigurationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/common/tests/AbstractAssertTestsClass.class */
/**
 * Base class for a "meta" test that checks that every discovered JUnit 5 test class extends a
 * required base class (returned by {@link #getBaseClass()}), unless it is explicitly excluded.
 *
 * <p>Test classes are discovered by scanning, with Reflections, the classpath roots returned by
 * {@code ClasspathHelper.forPackage(getPackageName())} for methods annotated with {@link Test},
 * and collecting the classes that declare those methods.
 */
public abstract class AbstractAssertTestsClass extends BaseND4JTest {
    private static final Logger log = LoggerFactory.getLogger(AbstractAssertTestsClass.class);

    /**
     * Returns the test classes that are allowed to not extend {@link #getBaseClass()}.
     *
     * @return the exempt classes; {@code contains} is called on the result, so it must not be
     *         {@code null}
     */
    protected abstract Set<Class<?>> getExclusions();

    /**
     * Returns the package name used to locate the classpath roots that are scanned for
     * {@link Test} methods.
     *
     * @return the package name passed to {@code ClasspathHelper.forPackage}
     */
    protected abstract String getPackageName();

    /**
     * Returns the class that every discovered test class must extend, directly or indirectly.
     *
     * @return the required base class
     */
    protected abstract Class<?> getBaseClass();

    /**
     * Overrides the per-test timeout from {@code BaseND4JTest}.
     *
     * @return 240 000 ms (4 minutes); presumably generous because classpath scanning can be slow
     */
    @Override
    public long getTimeoutMilliseconds() {
        return 240000L;
    }

    /**
     * Logs every discovered test class that neither extends {@link #getBaseClass()} nor appears in
     * {@link #getExclusions()}, then fails if any such class was found.
     *
     * <p>NOTE(review): only methods annotated with Jupiter's {@link Test} itself are looked up, so
     * classes whose tests use other annotations (e.g. {@code @ParameterizedTest}) are presumably
     * not checked.
     *
     * @throws AssertionError if at least one non-conforming, non-excluded test class is found
     */
    @Test
    public void checkTestClasses() {
        Class<?> baseClass = getBaseClass();
        Set<Class<?>> exclusions = getExclusions();

        List<Class<?>> violations = new ArrayList<>();
        for (Class<?> testClass : findTestClassesSortedByName()) {
            if (!baseClass.isAssignableFrom(testClass) && !exclusions.contains(testClass)) {
                log.error("Test {} does not extend {} (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", testClass, baseClass);
                violations.add(testClass);
            }
        }

        // Previously the violations were only counted and the count was discarded, so this test
        // could never fail; make the check actually enforce the requirement it logs.
        if (!violations.isEmpty()) {
            throw new AssertionError(violations.size() + " test class(es) do not extend "
                    + baseClass.getName() + " (directly or indirectly): " + violations);
        }
    }

    /**
     * Scans the classpath roots for {@link #getPackageName()} and returns the distinct classes
     * that declare at least one {@link Test}-annotated method, sorted by fully-qualified name so
     * that log output is deterministic.
     */
    private List<Class<?>> findTestClassesSortedByName() {
        Reflections reflections = new Reflections(new ConfigurationBuilder()
                .setUrls(ClasspathHelper.forPackage(getPackageName()))
                .setScanners(new MethodAnnotationsScanner()));
        Set<Method> testMethods = reflections.getMethodsAnnotatedWith(Test.class);

        // De-duplicate: a class with several @Test methods must be reported only once.
        Set<Class<?>> declaringClasses = new HashSet<>();
        for (Method method : testMethods) {
            declaringClasses.add(method.getDeclaringClass());
        }

        List<Class<?>> sorted = new ArrayList<>(declaringClasses);
        sorted.sort(Comparator.comparing(Class::getName));
        return sorted;
    }
}
