Converting CSV Files to TF Records

2022-01-21 python tensorflow csv file-io dataset

Problem Description

I've been running my script for more than 5 hours already. I have 258 CSV files that I want to convert to TF Records. I wrote the following script, and as I said, it has now been running for more than 5 hours:

import argparse
import os
import sys
import standardize_data
import tensorflow as tf

FLAGS = None
PATH = '/home/darth/GitHub Projects/gru_svm/dataset/train'

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def convert_to(dataset, name):
    """Converts a dataset to tfrecords"""

    filename_queue = tf.train.string_input_producer(dataset)

    # TF reader
    reader = tf.TextLineReader()

    # default values, in case of empty columns
    record_defaults = [[0.0] for x in range(24)]

    key, value = reader.read(filename_queue)

    # parentheses are required for a multi-line tuple unpacking;
    # without them this is a SyntaxError
    (duration, service, src_bytes, dest_bytes, count, same_srv_rate,
     serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count,
     dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
     flag, ids_detection, malware_detection, ashula_detection, label, src_ip_add,
     src_port_num, dst_ip_add, dst_port_num, start_time,
     protocol) = tf.decode_csv(value, record_defaults=record_defaults)

    features = tf.stack([duration, service, src_bytes, dest_bytes, count, same_srv_rate,
                        serror_rate, srv_serror_rate, dst_host_count, dst_host_srv_count,
                        dst_host_same_src_port_rate, dst_host_serror_rate, dst_host_srv_serror_rate,
                        flag, ids_detection, malware_detection, ashula_detection, src_ip_add,
                        src_port_num, dst_ip_add, dst_port_num, start_time, protocol])

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing {}'.format(filename))
    writer = tf.python_io.TFRecordWriter(filename)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            while not coord.should_stop():
                # fetch key in the same sess.run() call -- a separate
                # sess.run(key) would dequeue and silently skip a record
                raw_key, example, l = sess.run([key, features, label])
                print('Writing {dataset} : {example}, {label}'.format(
                        dataset=raw_key, example=example, label=l))
                example_to_write = tf.train.Example(features=tf.train.Features(feature={
                    'duration' : _float_feature(example[0]),
                    'service' : _int64_feature(int(example[1])),
                    'src_bytes' : _float_feature(example[2]),
                    'dest_bytes' : _float_feature(example[3]),
                    'count' : _float_feature(example[4]),
                    'same_srv_rate' : _float_feature(example[5]),
                    'serror_rate' : _float_feature(example[6]),
                    'srv_serror_rate' : _float_feature(example[7]),
                    'dst_host_count' : _float_feature(example[8]),
                    'dst_host_srv_count' : _float_feature(example[9]),
                    'dst_host_same_src_port_rate' : _float_feature(example[10]),
                    'dst_host_serror_rate' : _float_feature(example[11]),
                    'dst_host_srv_serror_rate' : _float_feature(example[12]),
                    'flag' : _int64_feature(int(example[13])),
                    'ids_detection' : _int64_feature(int(example[14])),
                    'malware_detection' : _int64_feature(int(example[15])),
                    'ashula_detection' : _int64_feature(int(example[16])),
                    'label' : _int64_feature(int(l)),
                    'src_ip_add' : _float_feature(example[17]),
                    'src_port_num' : _float_feature(example[18]),
                    'dst_ip_add' : _float_feature(example[19]),
                    'dst_port_num' : _float_feature(example[20]),
                    'start_time' : _float_feature(example[21]),
                    'protocol' : _int64_feature(int(example[22])),
                    }))
                writer.write(example_to_write.SerializeToString())
        except tf.errors.OutOfRangeError:
            print('Done converting -- EOF reached.')
        finally:
            coord.request_stop()
            # close in `finally`: the loop exits via OutOfRangeError,
            # so a close() placed after the loop body would never run
            writer.close()

        coord.join(threads)

def main(unused_argv):
    files = standardize_data.list_files(path=PATH)

    convert_to(dataset=files, name='train')

if __name__ == '__main__':
    # convert_to() reads FLAGS.directory, so FLAGS must be populated
    # before main() runs
    parser = argparse.ArgumentParser()
    parser.add_argument('--directory', type=str, default='/tmp/data',
                        help='Directory to write the TFRecord file to')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

This got me thinking: is it perhaps stuck in an infinite loop? What I want to do is read all the rows in each of the 258 CSV files and write those rows into a TF Record (as a feature and a label, of course), then stop the loop when there are no more rows available, that is, when the CSV files have been exhausted.

The standardize_data.list_files(path) is a function I wrote in a different module and re-used for this script. It returns a list of all the files found in PATH. Note that the files in my PATH are CSV files only.
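The real standardize_data.list_files() is not shown in the question, so its exact behavior is unknown; a minimal stand-in, assuming it simply collects the CSV paths under a directory, might look like this:

```python
import glob
import os

def list_files(path):
    """Return the sorted paths of the CSV files found under `path`.

    A hypothetical stand-in for standardize_data.list_files(), which is
    not shown in the question; the real helper may behave differently.
    """
    return sorted(glob.glob(os.path.join(path, '*.csv')))
```

Sorting makes the filename order deterministic, which helps when comparing the resulting TFRecord file across runs.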


Solution

Set num_epochs=1 in string_input_producer; the default is num_epochs=None, which cycles through the filenames indefinitely, so the reading loop never raises OutOfRangeError and never terminates. Another note: converting these CSVs to TFRecords may not offer the advantage you are looking for in TFRecords; the overhead is very high for this kind of data (a large number of small, single-value features/labels). You may want to experiment with this part.
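In the script above, the fix amounts to two changed lines (TF 1.x queue-runner API). Note that when num_epochs is set, string_input_producer tracks the epoch count in a local variable, which must be initialized inside the session before the queue runners start, or the queue will fail with an uninitialized-variable error:

```python
# limit the queue to a single pass over the 258 files
filename_queue = tf.train.string_input_producer(dataset, num_epochs=1)

# inside the `with tf.Session() as sess:` block, before
# tf.train.start_queue_runners(): initialize the epoch counter,
# which num_epochs stores as a local variable
sess.run(tf.local_variables_initializer())
```

With num_epochs=1, the queue raises tf.errors.OutOfRangeError after every file has been read once, which is exactly the condition the script's except clause already handles.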
