训练环境设置
文件地址
在使用 Tao 训练模型时,需要设置两套文件地址:
- USER_EXPERIMENT_DIR:docker 中的实验文档存放地址
- LOCAL_PROJECT_DIR:系统中的实验文档存放地址
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| import os
%env KEY="之前在 NGC 官网申请的秘钥" %env GPU_INDEX=0
%env USER_EXPERIMENT_DIR=/home/alex/7th_sky_hackathon
%env DATA_DOWNLOAD_DIR=/home/alex/7th_sky_hackathon/data
%env SPECS_DIR=/home/alex/7th_sky_hackathon/7th/specs
%env LOCAL_PROJECT_DIR=/home/alex/7th_sky_hackathon
os.environ["LOCAL_DATA_DIR"] = os.path.join(os.getenv("LOCAL_PROJECT_DIR", os.getcwd()), "data")
os.environ["LOCAL_EXPERIMENT_DIR"] = os.path.join(os.getenv("LOCAL_PROJECT_DIR", os.getcwd()), "7th")
os.environ["LOCAL_SPECS_DIR"] = os.path.join(os.getenv("NOTEBOOK_ROOT", os.getcwd()), "specs")
|
设置映射关系
将两套文件地址设置好映射后,只需对系统地址进行管理,docker 系统中的地址将会自动进行管理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| import json
mounts_file = os.path.expanduser("~/.tao_mounts.json")
drive_map = { "Mounts": [ { "source": os.environ["LOCAL_PROJECT_DIR"], "destination": "/home/alex/7th_sky_hackathon/" }, { "source": os.environ["LOCAL_SPECS_DIR"], "destination": os.environ["SPECS_DIR"] } ] }
with open(mounts_file, "w") as mfile: json.dump(drive_map, mfile, indent=4)
|
数据集处理
数据集简介
本次比赛,官方给出了 Pascal VOC 数据集,一共提供了三个压缩包,具体信息如下表
每个压缩包中包含相应图片和每张图片打完标签所对应的 XML 文件,每个样例以及相应标签的解释如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| <annotation> <folder></folder> <filename>cardboard1_jpg.rf.751f28e0ab21477466e5b0620fb70dde.jpg</filename> <path>cardboard1_jpg.rf.751f28e0ab21477466e5b0620fb70dde.jpg</path> <source> <database>roboflow.ai</database> </source> <size> <width>416</width> <height>416</height> <depth>3</depth> </size> <segmented>0</segmented> <object> <name>CARDBOARD</name> <pose>Unspecified</pose> <truncated>0</truncated> <difficult>0</difficult> <occluded>0</occluded> <bndbox> <xmin>29</xmin> <xmax>387</xmax> <ymin>8</ymin> <ymax>338</ymax> </bndbox> </object> </annotation>
|
需要注意的是,除官方所提供的数据集外,还可以使用自己的数据集,同时,由于官方给出的数据集中部分数据样本有问题,需要进行数据清洗与重复标注,重复标注阶段使用的工具是 labelimg
数据集要求
比赛要求将 VOC 标注格式转为 Kitti 标注格式,标注样例如下:
1
| banane 0.00 0 0.0 56.00 5.00 338.00 415.00 0.0 0.0 0.0 0.0 0.0 0.0 0.0
|
VOC 标注格式转 Kitti 标注格式
VOC 标注格式转为 Kitti 标注格式转换脚本如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46
| import os import xml.etree.ElementTree as ET from tqdm import tqdm
def convert(xml_dir, kitti_dir, filename): file = open(xml_dir + filename, "rb") root = ET.parse(file).getroot()
for obj in root.iter("object"): clazz = obj.find("name").text if clazz == "bottle": clazz = "bottle" elif clazz == "CARDBOARD": clazz = "cardboard" elif clazz == "banane" or clazz == "0": clazz = "banane" else: clazz = ""
bndbox = obj.find("bndbox") xmin = int(bndbox.find("xmin").text) xmax = int(bndbox.find("xmax").text) ymin = int(bndbox.find("ymin").text) ymax = int(bndbox.find("ymax").text)
if clazz == "bottle" or clazz == "cardboard" or clazz == "banane": clazz = clazz + " 0.00 0 0.0 " + str(xmin) + " " + str(ymin) + " " + str(xmax) + " " + str(ymax) + " 0.0 0.0 0.0 0.0 0.0 0.0 0.0\n" with open(kitti_dir + filename + ".txt", "w") as f: f.write(clazz)
if __name__ == "__main__": xml_dir = "/home/alex/7th_sky_hackathon/data/xml" kitti_dir = "/home/alex/7th_sky_hackathon/data/kitti"
xml_list = os.listdir(xml_dir) for i in tqdm(range(len(xml_list))): convert(xml_dir, kitti_dir, xml_list[i])
|
生成验证集
使用官方提供的脚本 generate_val_dataset.py
来生成验证集
1
| python3.6 generate_val_dataset.py -- input_image_dir=$LOCAL_DATA_DIR/images --input_label_dir=$LOCAL_DATA_DIR/labels --output_dir=$LOCAL_DATA_DIR/val
|
脚本内容如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
| from __future__ import absolute_import from __future__ import division from __future__ import print_function
import argparse import os
def parse_args(args=None): parser = argparse.ArgumentParser(description='生成验证数据集')
parser.add_argument( "--input_image_dir", type=str, required=True, help="Kitti 训练集图像的输入目录" )
parser.add_argument( "--input_label_dir", type=str, required=True, help="Kitti 训练集标注的输入目录" )
parser.add_argument( "--output_dir", type=str, required=True, help="验证集输出目录" )
parser.add_argument( "--val_split", type=int, required=False, default=10, help="验证集占训练集的百分比" )
return parser.parse_args(args)
def main(args=None):
args = parse_args(args)
img_files = [] for file_name in os.listdir(args.input_image_dir): if file_name.split(".")[-1] == "jpg": img_files.append(file_name) total_cnt = len(img_files)
val_ratio = float(args.val_split) / 100.0 val_cnt = int(total_cnt * val_ratio) train_cnt = total_cnt - val_cnt val_img_list = img_files[0:val_cnt] target_img_path = os.path.join(args.output_dir, "image") target_label_path = os.path.join(args.output_dir, "label")
if not os.path.exists(target_img_path): os.makedirs(target_img_path) else: print("图像输出路径已存在") return
if not os.path.exists(target_label_path): os.makedirs(target_label_path) else: print("标注输出路径已存在") return
print("总计 {} 个 Kitti 样本".format(total_cnt)) print("划分出 {} 个训练样本,以及 {} 个验证样本".format(train_cnt, val_cnt))
for img_name in val_img_list: label_name = img_name.split(".jpg")[0] + ".txt" os.rename(os.path.join(args.input_image_dir, img_name), os.path.join(target_img_path, img_name)) os.rename(os.path.join(args.input_label_dir, label_name), os.path.join(target_label_path, label_name))
if __name__ == "__main__": main()
|
将 Kitti 格式数据集转为 TFRecords 格式
根据实际目录配置,修改 specs/ssd_tfrecords_kitti_train.txt
中的内容,其中需要修改的内容如下:
字段 |
含义 |
root_directory_path |
docker 环境中的地址 |
image_dir_name |
保存图片数据的文件夹 |
label_dir_name |
保存标注文件的文件夹 |
target_class_mapping |
标注的类别,key 为标注文件中写的类别,value 为 Tao 训练时使用的类别 |
修改后的内容如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
| kitti_config { root_directory_path: "/home/alex/7th_sky_hackathon/data" image_dir_name: "images" label_dir_name: "labels" image_extension: ".jpg" partition_mode: "random" num_partitions: 2 val_split: 0 num_shards: 10 } image_directory_path: "/home/alex/7th_sky_hackathon/data" target_class_mapping { key: "cardboard" value: "cardboard" } target_class_mapping { key: "bottle" value: "bottle" } target_class_mapping { key: "banane" value: "banane" }
|
之后,在 $LOCAL_DATA_DIR
目录下,创建用于保存 tfrecord 格式数据的文件夹 tfrecords
,再利用 tao 自带的工具转换数据,将训练集转为 tfrecord 格式
1 2
| $LOCAL_DATA_DIR/tfrecords/* !tao ssd dataset_convert -d home/alex/7th_sky_hackathon/7th/specs/ssd_tfrecords_kitti_train.txt -o /home/alex/7th_sky_hackathon/data/tfrecords/kitti_train
|