-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNetTIME_CRF_evaluate.py
More file actions
124 lines (112 loc) · 3.06 KB
/
NetTIME_CRF_evaluate.py
File metadata and controls
124 lines (112 loc) · 3.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
from NetTIME import CRFEvaluateWorkflow
######## User Input ########
# Command-line interface for evaluating a linear chain CRF classifier trained
# on NetTIME predictions. Help strings use argparse's ``%(default)s``
# substitution so the documented default always matches ``default=``.
parser = argparse.ArgumentParser(
    # ArgumentParser's first positional parameter is ``prog`` (the program name
    # shown in the usage line), so the description must be passed by keyword.
    description="Evaluate a linear chain CRF classifier trained on NetTIME predictions."
)
# Validation parameters
parser.add_argument(
    "--batch_size",
    type=int,
    default=2700,
    help="Evaluation batch size. Default: %(default)s",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=10,
    help="Number of workers used to perform multi-process data loading. "
    "Default: %(default)s",
)
parser.add_argument(
    "--seed", type=int, default=1111, help="Random seed. Default: %(default)s"
)
parser.add_argument(
    "--model_config",
    type=str,
    default=None,
    help="Specify an alternative path to CRF .config file.",
)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    default=None,
    help="Specify an alternative location to find checkpoints to be evaluated.",
)
# Data
parser.add_argument(
    "--dataset",
    type=str,
    default="data/datasets/training_example/validation_minOverlap200_maxUnion600_example.h5",
    help="Path to NetTIME evaluation data containing target labels. Default: "
    "%(default)s",
)
parser.add_argument(
    "--prediction_dir",
    type=str,
    default=None,
    help="Path to NetTIME prediction directory.",
)
parser.add_argument(
    "--dtype",
    type=str,
    default="VALIDATION",
    help="Dataset type. Default: %(default)s.",
)
parser.add_argument(
    "--class_weight",
    type=str,
    default="data/datasets/training_example/validation_minOverlap200_maxUnion600_example_weight.npy",
    help="Path to a numpy .npy file specifying the class weight. Default: "
    "%(default)s",
)
# Save
parser.add_argument(
    "--output_dir",
    type=str,
    default="experiments/",
    # Trailing space added: the two literals below are concatenated implicitly.
    help="Root directory for saving experiment results. "
    "Default: %(default)s",
)
parser.add_argument(
    "--experiment_name",
    type=str,
    default="training_example",
    help="experiment name.",
)
parser.add_argument(
    "--result_dir",
    type=str,
    default=None,
    help="Specify an alternative location to save checkpoint evaluation files.",
)
parser.add_argument(
    "--tmp_dir",
    type=str,
    default="/tmp",
    help="Temporary directory for saving merged prediction .h5 file. Default: "
    "%(default)s",
)
# Parse sys.argv; argparse exits with a usage message on invalid input.
args = parser.parse_args()
######## Configure workflow ########
# Instantiate the CRF evaluation workflow and copy each parsed command-line
# option onto it as a same-named attribute, preserving the original order.
workflow = CRFEvaluateWorkflow()
for _option in (
    # Validation parameters
    "batch_size",
    "num_workers",
    "seed",
    "model_config",
    "ckpt_dir",
    # Data
    "dataset",
    "prediction_dir",
    "dtype",
    "class_weight",
    # Save
    "output_dir",
    "experiment_name",
    "result_dir",
    "tmp_dir",
):
    setattr(workflow, _option, getattr(args, _option))
# The complete argparse namespace is attached too, alongside the copied fields.
workflow.args = args
######## Model Run ########
workflow.run()