package org.datavec.api.records.reader.impl.regex;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.SequenceRecord;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.class */
/**
 * Sequence record reader that turns a text source into one sequence by applying a regular
 * expression to each line.
 *
 * <p>The text is split into lines on {@code \r\n} or {@code \n}. The first {@code skipNumLines}
 * lines are ignored. Every remaining line must match the regex in its entirety (see
 * {@link Matcher#matches()}); each capturing group of the match becomes one {@link Text} value of
 * that time step. Lines that do not match are handled according to the configured
 * {@link LineErrorHandling}.
 *
 * <p>The charset field is transient; its name is written to and restored from the serialization
 * stream by {@code writeObject}/{@code readObject}.
 */
public class RegexSequenceRecordReader extends FileRecordReader implements SequenceRecordReader {
    /** Configuration key that overrides the number of leading lines to skip (see {@link #initialize}). */
    public static final String SKIP_NUM_LINES = NAME_SPACE + ".skipnumlines";
    /** Charset used when none is supplied to the constructor. */
    public static final Charset DEFAULT_CHARSET = Charset.forName("UTF-8");
    /** Error handling mode used when none is supplied to the constructor. */
    public static final LineErrorHandling DEFAULT_ERROR_HANDLING = LineErrorHandling.FailOnInvalid;
    public static final Logger LOG = LoggerFactory.getLogger((Class<?>) RegexSequenceRecordReader.class);
    private String regex;
    private int skipNumLines;
    private Pattern pattern;
    // Not serialized by default serialization; handled explicitly in readObject/writeObject.
    private transient Charset charset;
    private LineErrorHandling errorHandling;

    /** What to do with a (non-skipped) line that does not match the regex. */
    public enum LineErrorHandling {
        /** Throw an {@link IllegalStateException} describing the offending line. */
        FailOnInvalid,
        /** Drop the line without any notification. */
        SkipInvalid,
        /** Drop the line and log a warning. */
        SkipInvalidWithWarning
    }

    /**
     * Creates a reader using {@link #DEFAULT_CHARSET} and {@link #DEFAULT_ERROR_HANDLING}.
     *
     * @param str regular expression applied to each line
     * @param i   number of leading lines to skip in each source
     */
    public RegexSequenceRecordReader(String str, int i) {
        this(str, i, DEFAULT_CHARSET, DEFAULT_ERROR_HANDLING);
    }

    /**
     * Creates a reader.
     *
     * @param str               regular expression applied to each line; compiled immediately
     * @param i                 number of leading lines to skip in each source
     * @param charset           charset used to decode the input; {@code null} selects
     *                          {@link #DEFAULT_CHARSET}
     * @param lineErrorHandling handling of non-matching lines; {@code null} selects
     *                          {@link #DEFAULT_ERROR_HANDLING}
     */
    public RegexSequenceRecordReader(String str, int i, Charset charset, LineErrorHandling lineErrorHandling) {
        this.regex = str;
        this.skipNumLines = i;
        this.pattern = Pattern.compile(str);
        // Fall back to the defaults rather than failing later with a NullPointerException
        // (in charset.name(), during serialization, or in the error-handling switch).
        this.charset = charset != null ? charset : DEFAULT_CHARSET;
        this.errorHandling = lineErrorHandling != null ? lineErrorHandling : DEFAULT_ERROR_HANDLING;
    }

    /**
     * Initializes the underlying file reader, then reads {@link #SKIP_NUM_LINES} from the
     * configuration, passing the current skip count as the default value.
     */
    @Override // org.datavec.api.records.reader.impl.FileRecordReader, org.datavec.api.records.reader.RecordReader
    public void initialize(Configuration configuration, InputSplit inputSplit) throws IOException, InterruptedException {
        super.initialize(configuration, inputSplit);
        this.skipNumLines = configuration.getInt(SKIP_NUM_LINES, this.skipNumLines);
    }

    /** Returns the sequence data of {@link #nextSequence()}. */
    @Override // org.datavec.api.records.reader.SequenceRecordReader
    public List<List<Writable>> sequenceRecord() {
        return nextSequence().getSequenceRecord();
    }

    /**
     * Reads the whole stream as text in this reader's charset and parses it as one sequence.
     *
     * @param uri             identifier of the source, used only in error and warning messages
     * @param dataInputStream stream to read; it is not closed by this method
     * @return the parsed sequence, one inner list per matching line
     * @throws IOException if reading the stream fails
     */
    @Override // org.datavec.api.records.reader.SequenceRecordReader
    public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException {
        String contents = IOUtils.toString(new BufferedInputStream(dataInputStream), this.charset.name());
        return loadSequence(contents, uri);
    }

    /**
     * Parses text into a sequence: one time step per matching line, one {@link Text} per
     * capturing group.
     *
     * @param contents full text of the source
     * @param uri      identifier of the source, used only in error and warning messages
     * @return the parsed sequence (possibly empty)
     * @throws IllegalStateException if a line does not match and the mode is
     *                               {@link LineErrorHandling#FailOnInvalid}
     */
    private List<List<Writable>> loadSequence(String contents, URI uri) {
        String[] lines = contents.split("(\r\n)|\n");
        List<List<Writable>> sequence = new ArrayList<List<Writable>>();
        int lineNumber = 0; // 1-based number of the line currently being processed
        for (String line : lines) {
            lineNumber++;
            if (lineNumber <= this.skipNumLines) {
                continue; // still inside the skipped header
            }
            Matcher matcher = this.pattern.matcher(line);
            if (matcher.matches()) {
                int groupCount = matcher.groupCount();
                List<Writable> timeStep = new ArrayList<Writable>(groupCount);
                // Group 0 is the whole match; only the capturing groups are emitted.
                for (int group = 1; group <= groupCount; group++) {
                    timeStep.add(new Text(matcher.group(group)));
                }
                sequence.add(timeStep);
                continue;
            }
            switch (this.errorHandling) {
                case FailOnInvalid:
                    throw new IllegalStateException("Invalid line: line does not match regex (line #" + lineNumber
                            + ", uri=\"" + uri + "\", regex=\"" + this.regex + "\"); line=\"" + line + "\"");
                case SkipInvalid:
                    break;
                case SkipInvalidWithWarning:
                    LOG.warn("Skipping invalid line: line does not match regex (line #{}, uri=\"{}\"); line=\"{}\"",
                            lineNumber, uri, line);
                    break;
                default:
                    throw new RuntimeException("Unknown error handling mode: " + this.errorHandling);
            }
        }
        return sequence;
    }

    /** Delegates to the superclass reset. */
    @Override // org.datavec.api.records.reader.impl.FileRecordReader, org.datavec.api.records.reader.RecordReader
    public void reset() {
        super.reset();
    }

    /**
     * Reads the next file in full and parses it as one sequence, attaching metadata that records
     * the file's URI.
     *
     * @return the parsed sequence together with its metadata
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    @Override // org.datavec.api.records.reader.SequenceRecordReader
    public SequenceRecord nextSequence() {
        File file = nextFile();
        URI fileUri = file.toURI();
        try {
            String contents = FileUtils.readFileToString(file, this.charset.name());
            List<List<Writable>> sequence = loadSequence(contents, fileUri);
            return new org.datavec.api.records.impl.SequenceRecord(sequence,
                    new RecordMetaDataURI(fileUri, RegexSequenceRecordReader.class));
        } catch (IOException e) {
            throw new RuntimeException("Error reading sequence from file: " + file, e);
        }
    }

    /**
     * Loads the single sequence identified by the given metadata.
     *
     * @throws IOException if the file cannot be read
     */
    @Override // org.datavec.api.records.reader.SequenceRecordReader
    public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0);
    }

    /**
     * Loads one sequence per metadata entry, in the order given. Each entry's URI is opened as a
     * local file, read in full and parsed; the original metadata object is attached to the result.
     *
     * @throws IOException if any file cannot be read
     */
    @Override // org.datavec.api.records.reader.SequenceRecordReader
    public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> list) throws IOException {
        List<SequenceRecord> records = new ArrayList<SequenceRecord>(list.size());
        for (RecordMetaData metaData : list) {
            File file = new File(metaData.getURI());
            String contents = FileUtils.readFileToString(file, this.charset.name());
            List<List<Writable>> sequence = loadSequence(contents, file.toURI());
            records.add(new org.datavec.api.records.impl.SequenceRecord(sequence, metaData));
        }
        return records;
    }

    /** Restores the default-serialized fields, then the charset from the name written by writeObject. */
    private void readObject(ObjectInputStream objectInputStream) throws ClassNotFoundException, IOException {
        objectInputStream.defaultReadObject();
        this.charset = Charset.forName(objectInputStream.readUTF());
    }

    /** Writes the default-serialized fields, then the charset name (the charset field is transient). */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeUTF(this.charset.name());
    }
}
