package io.unitycatalog.server.auth.decorator;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.internal.server.annotation.AnnotatedService;
import com.linecorp.armeria.server.DecoratingHttpServiceFunction;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.SimpleDecoratingHttpService;
import com.linecorp.armeria.server.annotation.Param;
import io.unitycatalog.server.auth.UnityCatalogAuthorizer;
import io.unitycatalog.server.auth.annotation.AuthorizeExpression;
import io.unitycatalog.server.auth.annotation.AuthorizeKey;
import io.unitycatalog.server.auth.annotation.AuthorizeKeys;
import io.unitycatalog.server.auth.decorator.KeyLocator;
import io.unitycatalog.server.exception.BaseException;
import io.unitycatalog.server.exception.ErrorCode;
import io.unitycatalog.server.model.SecurableType;
import io.unitycatalog.server.utils.IdentityUtils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/unitycatalog/server/auth/decorator/UnityAccessDecorator.class */
public class UnityAccessDecorator implements DecoratingHttpServiceFunction {
    private static final Logger LOGGER = LoggerFactory.getLogger(UnityAccessDecorator.class);
    public static final ObjectMapper MAPPER = new ObjectMapper();
    private final UnityAccessEvaluator evaluator;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/unitycatalog/server/auth/decorator/UnityAccessDecorator$PeekDataHandler.class */
    public static class PeekDataHandler {
        private final MediaType contentType;
        private final List<KeyLocator> payloadLocators;
        private final Map<SecurableType, Object> resourceKeys;
        private final ByteArrayOutputStream dataStream = new ByteArrayOutputStream();

        private PeekDataHandler(MediaType mediaType, List<KeyLocator> list, Map<SecurableType, Object> map) {
            this.contentType = mediaType;
            this.payloadLocators = list;
            this.resourceKeys = map;
        }

        private boolean processPeekData(HttpData httpData) {
            if (!this.contentType.equals(MediaType.JSON)) {
                UnityAccessDecorator.LOGGER.warn("Skipping content-type: {}", this.contentType);
                return false;
            }
            try {
                this.dataStream.write(httpData.array());
                UnityAccessDecorator.LOGGER.debug("Payload: {}", this.dataStream.toString());
            } catch (IOException e) {
            }
            if (httpData.array()[httpData.array().length - 1] != 125) {
                return false;
            }
            try {
                Map map = (Map) UnityAccessDecorator.MAPPER.readValue(this.dataStream.toByteArray(), new TypeReference<Map<String, Object>>() { // from class: io.unitycatalog.server.auth.decorator.UnityAccessDecorator.PeekDataHandler.1
                });
                this.payloadLocators.forEach(keyLocator -> {
                    this.resourceKeys.put(keyLocator.getType(), UnityAccessDecorator.findPayloadValue(keyLocator.getKey(), map));
                });
                return true;
            } catch (IOException e2) {
                UnityAccessDecorator.LOGGER.warn("Error parsing payload: {}", e2.getMessage());
                return false;
            }
        }
    }

    public UnityAccessDecorator(UnityCatalogAuthorizer unityCatalogAuthorizer) throws BaseException {
        try {
            this.evaluator = new UnityAccessEvaluator(unityCatalogAuthorizer);
        } catch (IllegalAccessException | NoSuchMethodException e) {
            throw new BaseException(ErrorCode.INTERNAL, "Error initializing access evaluator.", e);
        }
    }

    public HttpResponse serve(HttpService httpService, ServiceRequestContext serviceRequestContext, HttpRequest httpRequest) throws Exception {
        LOGGER.debug("AccessDecorator checking {}", httpRequest.path());
        Method findServiceMethod = findServiceMethod(serviceRequestContext.config().service());
        if (findServiceMethod != null) {
            String findAuthorizeExpression = findAuthorizeExpression(findServiceMethod);
            List<KeyLocator> findAuthorizeKeys = findAuthorizeKeys(findServiceMethod);
            if (findAuthorizeExpression == null) {
                LOGGER.debug("No authorization expression found.");
            } else {
                if (!findAuthorizeKeys.isEmpty()) {
                    return authorizeByRequest(httpService, serviceRequestContext, httpRequest, IdentityUtils.findPrincipalId(), findAuthorizeKeys, findAuthorizeExpression);
                }
                LOGGER.warn("No authorization resource(s) found.");
            }
        } else {
            LOGGER.warn("Couldn't unwrap service.");
        }
        return httpService.serve(serviceRequestContext, httpRequest);
    }

    private HttpResponse authorizeByRequest(HttpService httpService, ServiceRequestContext serviceRequestContext, HttpRequest httpRequest, UUID uuid, List<KeyLocator> list, String str) throws Exception {
        HashMap hashMap = new HashMap();
        List<KeyLocator> list2 = list.stream().filter(keyLocator -> {
            return keyLocator.getSource().equals(KeyLocator.Source.SYSTEM);
        }).toList();
        List<KeyLocator> list3 = list.stream().filter(keyLocator2 -> {
            return keyLocator2.getSource().equals(KeyLocator.Source.PARAM);
        }).toList();
        List<KeyLocator> list4 = list.stream().filter(keyLocator3 -> {
            return keyLocator3.getSource().equals(KeyLocator.Source.PAYLOAD);
        }).toList();
        list2.forEach(keyLocator4 -> {
            hashMap.put(keyLocator4.getType(), "metastore");
        });
        list3.forEach(keyLocator5 -> {
            hashMap.put(keyLocator5.getType(), serviceRequestContext.pathParam(keyLocator5.getKey()) != null ? serviceRequestContext.pathParam(keyLocator5.getKey()) : serviceRequestContext.queryParam(keyLocator5.getKey()));
        });
        if (list4.isEmpty()) {
            LOGGER.debug("Checking authorization before method.");
            checkAuthorization(uuid, str, hashMap);
            return httpService.serve(serviceRequestContext, httpRequest);
        }
        LOGGER.debug("Checking authorization before in peekData.");
        PeekDataHandler peekDataHandler = new PeekDataHandler(httpRequest.contentType(), list4, hashMap);
        return httpService.serve(serviceRequestContext, httpRequest.peekData(httpData -> {
            LOGGER.debug("Authorization peekData invoked.");
            if (peekDataHandler.processPeekData(httpData)) {
                checkAuthorization(uuid, str, hashMap);
            }
        }));
    }

    private static Object findPayloadValue(String str, Map<String, Object> map) {
        String[] split = str.split("[.]", 2);
        if (split.length == 1) {
            return map.get(split[0]);
        }
        if (!(map.get(split[0]) instanceof Map)) {
            return null;
        }
        return findPayloadValue(split[1], (Map) map.get(split[0]));
    }

    private void checkAuthorization(UUID uuid, String str, Map<SecurableType, Object> map) {
        LOGGER.debug("resourceKeys = {}", map);
        Map<SecurableType, Object> mapResourceKeys = KeyMapperUtil.mapResourceKeys(map);
        if (!mapResourceKeys.keySet().containsAll(map.keySet())) {
            LOGGER.warn("Some resource keys have unresolved ids.");
        }
        LOGGER.debug("resourceIds = {}", mapResourceKeys);
        if (!this.evaluator.evaluate(uuid, str, mapResourceKeys)) {
            throw new BaseException(ErrorCode.PERMISSION_DENIED, "Access denied.");
        }
    }

    private static String findAuthorizeExpression(Method method) {
        AuthorizeExpression authorizeExpression = (AuthorizeExpression) method.getAnnotation(AuthorizeExpression.class);
        if (authorizeExpression != null) {
            LOGGER.debug("authorize expression = {}", authorizeExpression.value());
            return authorizeExpression.value();
        }
        LOGGER.debug("authorize = (none found)");
        return null;
    }

    private static List<KeyLocator> findAuthorizeKeys(Method method) {
        ArrayList arrayList = new ArrayList();
        AuthorizeKey authorizeKey = (AuthorizeKey) method.getAnnotation(AuthorizeKey.class);
        if (authorizeKey != null) {
            arrayList.add(KeyLocator.builder().source(KeyLocator.Source.SYSTEM).type(authorizeKey.value()).build());
        }
        for (Parameter parameter : method.getParameters()) {
            AuthorizeKey authorizeKey2 = (AuthorizeKey) parameter.getAnnotation(AuthorizeKey.class);
            AuthorizeKeys authorizeKeys = (AuthorizeKeys) parameter.getAnnotation(AuthorizeKeys.class);
            if (authorizeKey2 != null && authorizeKeys != null) {
                LOGGER.warn("Both AuthorizeKey and AuthorizeKeys present");
            }
            ArrayList<AuthorizeKey> arrayList2 = new ArrayList();
            if (authorizeKey2 != null) {
                arrayList2.add(authorizeKey2);
            }
            if (authorizeKeys != null) {
                arrayList2.addAll(Arrays.asList(authorizeKeys.value()));
            }
            for (AuthorizeKey authorizeKey3 : arrayList2) {
                if (authorizeKey3.key().isEmpty()) {
                    Param annotation = parameter.getAnnotation(Param.class);
                    if (annotation != null) {
                        arrayList.add(KeyLocator.builder().source(KeyLocator.Source.PARAM).type(authorizeKey3.value()).key(annotation.value()).build());
                    } else {
                        LOGGER.warn("Couldn't find param key for authorization key");
                    }
                } else {
                    arrayList.add(KeyLocator.builder().source(KeyLocator.Source.PAYLOAD).type(authorizeKey3.value()).key(authorizeKey3.key()).build());
                }
            }
        }
        return arrayList;
    }

    private static Method findServiceMethod(HttpService httpService) throws ClassNotFoundException {
        SimpleDecoratingHttpService unwrap = httpService.unwrap();
        if (!(unwrap instanceof SimpleDecoratingHttpService)) {
            return null;
        }
        AnnotatedService unwrap2 = unwrap.unwrap();
        if (!(unwrap2 instanceof AnnotatedService)) {
            return null;
        }
        AnnotatedService annotatedService = unwrap2;
        LOGGER.debug("serviceName = {}, methodName = {}", annotatedService.serviceName(), annotatedService.methodName());
        List<Method> findMethodsByName = findMethodsByName(Class.forName(annotatedService.serviceName()), annotatedService.methodName());
        if (findMethodsByName.size() == 1) {
            return findMethodsByName.get(0);
        }
        return null;
    }

    private static List<Method> findMethodsByName(Class<?> cls, String str) {
        ArrayList arrayList = new ArrayList();
        for (Method method : cls.getDeclaredMethods()) {
            if (method.getName().equals(str)) {
                arrayList.add(method);
            }
        }
        return arrayList;
    }
}
