Description
I fine-tuned the TrafficCamNet model (unpruned version) on my custom dataset. The original model was trained to detect the following labels:
- Car
- Bicycle
- Person
- Road Sign
My custom dataset has four new labels:
- Heavy
- Two Wheeler
- Three Wheeler
- Four Wheeler
I trained the model for 120 epochs and successfully obtained a .tlt model. After converting the model to ONNX format and using it for inference, the model detects only the “Four Wheeler” class (label 3). None of the other labels are being identified.
Here’s some additional information about my dataset:
- Class 0 (Heavy): Total occurrences = 82,089
- Class 1 (Two Wheeler): Total occurrences = 43,457
- Class 2 (Three Wheeler): Total occurrences = 29,337
- Class 3 (Four Wheeler): Total occurrences = 97,417
I believe this is a sufficient amount of data for training.
@Morganh, could you please help me understand the possible reasons why the other labels are not being detected?
This is my spec.txt, which is used at training time.
random_seed: 42
dataset_config {
data_sources {
tfrecords_path: "/home/trainval/tfrecords/*"
image_directory_path: "/home/trainval"
}
image_extension: "jpg"
target_class_mapping {
key: "four_wheeler"
value: "four_wheeler"
}
target_class_mapping {
key: "three_wheeler"
value: "three_wheeler"
}
target_class_mapping {
key: "two_wheeler"
value: "two_wheeler"
}
target_class_mapping {
key: "heavy"
value: "heavy"
}
validation_fold: 0
}
augmentation_config {
preprocessing {
output_image_width: 1248
output_image_height: 384
min_bbox_width: 1.0
min_bbox_height: 1.0
output_image_channel: 3
}
spatial_augmentation {
hflip_probability: 0.5
zoom_min: 1.0
zoom_max: 1.0
translate_max_x: 8.0
translate_max_y: 8.0
}
color_augmentation {
hue_rotation_max: 25.0
saturation_shift_max: 0.20000000298
contrast_scale_max: 0.10000000149
contrast_center: 0.5
}
}
postprocessing_config {
target_class_config {
key: "four_wheeler"
value {
clustering_config {
clustering_algorithm: DBSCAN
dbscan_confidence_threshold: 0.9
coverage_threshold: 0.00499999988824
dbscan_eps: 0.20000000298
dbscan_min_samples: 1
minimum_bounding_box_height: 20
}
}
}
target_class_config {
key: "three_wheeler"
value {
clustering_config {
clustering_algorithm: DBSCAN
dbscan_confidence_threshold: 0.9
coverage_threshold: 0.00499999988824
dbscan_eps: 0.15000000596
dbscan_min_samples: 1
minimum_bounding_box_height: 20
}
}
}
target_class_config {
key: "two_wheeler"
value {
clustering_config {
clustering_algorithm: DBSCAN
dbscan_confidence_threshold: 0.9
coverage_threshold: 0.00749999983236
dbscan_eps: 0.230000004172
dbscan_min_samples: 1
minimum_bounding_box_height: 20
}
}
}
target_class_config {
key: "heavy"
value {
clustering_config {
clustering_algorithm: DBSCAN
dbscan_confidence_threshold: 0.9
coverage_threshold: 0.00749999983236
dbscan_eps: 0.230000004172
dbscan_min_samples: 1
minimum_bounding_box_height: 20
}
}
}
}
model_config {
pretrained_model_file: "/home/trainval/model/resnet18_trafficcamnet.tlt"
num_layers: 18
use_batch_norm: true
objective_set {
bbox {
scale: 35.0
offset: 0.5
}
cov {
}
}
arch: "resnet"
}
evaluation_config {
validation_period_during_training: 10
first_validation_epoch: 1
minimum_detection_ground_truth_overlap {
key: "four_wheeler"
value: 0.699999988079
}
minimum_detection_ground_truth_overlap {
key: "three_wheeler"
value: 0.5
}
minimum_detection_ground_truth_overlap {
key: "two_wheeler"
value: 0.5
}
minimum_detection_ground_truth_overlap {
key: "heavy"
value: 0.5
}
evaluation_box_config {
key: "four_wheeler"
value {
minimum_height: 20
maximum_height: 9999
minimum_width: 10
maximum_width: 9999
}
}
evaluation_box_config {
key: "three_wheeler"
value {
minimum_height: 20
maximum_height: 9999
minimum_width: 10
maximum_width: 9999
}
}
evaluation_box_config {
key: "two_wheeler"
value {
minimum_height: 20
maximum_height: 9999
minimum_width: 10
maximum_width: 9999
}
}
evaluation_box_config {
key: "heavy"
value {
minimum_height: 20
maximum_height: 9999
minimum_width: 10
maximum_width: 9999
}
}
average_precision_mode: INTEGRATE
}
cost_function_config {
target_classes {
name: "four_wheeler"
class_weight: 1.0
coverage_foreground_weight: 0.0500000007451
objectives {
name: "cov"
initial_weight: 1.0
weight_target: 1.0
}
objectives {
name: "bbox"
initial_weight: 10.0
weight_target: 10.0
}
}
target_classes {
name: "three_wheeler"
class_weight: 8.0
coverage_foreground_weight: 0.0500000007451
objectives {
name: "cov"
initial_weight: 1.0
weight_target: 1.0
}
objectives {
name: "bbox"
initial_weight: 10.0
weight_target: 1.0
}
}
target_classes {
name: "two_wheeler"
class_weight: 4.0
coverage_foreground_weight: 0.0500000007451
objectives {
name: "cov"
initial_weight: 1.0
weight_target: 1.0
}
objectives {
name: "bbox"
initial_weight: 10.0
weight_target: 10.0
}
}
target_classes {
name: "heavy"
class_weight: 4.0
coverage_foreground_weight: 0.0500000007451
objectives {
name: "cov"
initial_weight: 1.0
weight_target: 1.0
}
objectives {
name: "bbox"
initial_weight: 10.0
weight_target: 10.0
}
}
enable_autoweighting: false
max_objective_weight: 0.999899983406
min_objective_weight: 9.99999974738e-05
}
training_config {
batch_size_per_gpu: 4
num_epochs: 120
learning_rate {
soft_start_annealing_schedule {
min_learning_rate: 5e-07
max_learning_rate: 5e-05
soft_start: 0.10000000149
annealing: 0.699999988079
}
}
regularizer {
type: L1
weight: 3.00000002618e-09
}
optimizer {
adam {
epsilon: 9.99999993923e-09
beta1: 0.899999976158
beta2: 0.999000012875
}
}
cost_scaling {
initial_exponent: 20.0
increment: 0.005
decrement: 1.0
}
checkpoint_interval: 10
}
bbox_rasterizer_config {
target_class_config {
key: "four_wheeler"
value {
cov_center_x: 0.5
cov_center_y: 0.5
cov_radius_x: 0.40000000596
cov_radius_y: 0.40000000596
bbox_min_radius: 1.0
}
}
target_class_config {
key: "three_wheeler"
value {
cov_center_x: 0.5
cov_center_y: 0.5
cov_radius_x: 1.0
cov_radius_y: 1.0
bbox_min_radius: 1.0
}
}
target_class_config {
key: "two_wheeler"
value {
cov_center_x: 0.5
cov_center_y: 0.5
cov_radius_x: 1.0
cov_radius_y: 1.0
bbox_min_radius: 1.0
}
}
target_class_config {
key: "heavy"
value {
cov_center_x: 0.5
cov_center_y: 0.5
cov_radius_x: 1.0
cov_radius_y: 1.0
bbox_min_radius: 1.0
}
}
deadzone_radius: 0.400000154972
}