【发布时间】:2021-06-29 15:32:30
【问题描述】:
我正在尝试在此 Github 存储库中运行 model.py 代码。
我正在使用 Jupyter Notebook。我已经克隆了这个存储库，并在该目录（即仓库被克隆到的位置）下启动了 Jupyter Notebook，这样导入模块时就不会出错。
当我运行 model.py 文件中的这部分代码时：
def main(argv=None):
    """
    Load train/validation data set and train the model.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, exactly as a plain ``parser.parse_args()`` call does,
        so running the script from a shell is unchanged. Pass an explicit list
        (e.g. ``main([])`` or ``main(['-n', '5'])``) when calling from a Jupyter
        notebook: there ``sys.argv`` holds the kernel's own launch arguments,
        which this parser does not define and would reject with an
        "unrecognized arguments" error.
    """
    parser = argparse.ArgumentParser(description='Behavioral Cloning Training Program')
    parser.add_argument('-d', help='data directory', dest='data_dir', type=str, default='data')
    parser.add_argument('-t', help='test size fraction', dest='test_size', type=float, default=0.2)
    parser.add_argument('-k', help='drop out probability', dest='keep_prob', type=float, default=0.5)
    parser.add_argument('-n', help='number of epochs', dest='nb_epoch', type=int, default=10)
    parser.add_argument('-s', help='samples per epoch', dest='samples_per_epoch', type=int, default=20000)
    parser.add_argument('-b', help='batch size', dest='batch_size', type=int, default=40)
    # s2b is a string->bool converter defined elsewhere in this file (not visible here).
    parser.add_argument('-o', help='save best models only', dest='save_best_only', type=s2b, default='true')
    parser.add_argument('-l', help='learning rate', dest='learning_rate', type=float, default=1.0e-4)
    # Strict parsing on purpose: unlike parse_known_args(), typos in option
    # names still fail loudly instead of being silently ignored.
    args = parser.parse_args(argv)

    # print parameters
    print('-' * 30)
    print('Parameters')
    print('-' * 30)
    for key, value in vars(args).items():
        print('{:<20} := {}'.format(key, value))
    print('-' * 30)

    # load data
    data = load_data(args)
    # build model
    model = build_model(args)
    # train model on data (per the original note, it saves as model.h5)
    train_model(model, args, *data)


if __name__ == '__main__':
    main()
它给出了以下错误:
<ipython-input-16-6e1430362122> in <module>
32
33 if __name__ == '__main__':
---> 34 main()
<ipython-input-16-6e1430362122> in main()
26 data = load_data(args)
27 #build model
---> 28 model = build_model(args)
29 #train model on data, it saves as model.h5
30 train_model(model, args, *data)
<ipython-input-6-b4a45377398f> in build_model(args)
21 model = Sequential()
22 model.add(Lambda(lambda x: x/127.5-1.0, input_shape=INPUT_SHAPE))
---> 23 model.add(Conv2D(24, 5, 5, activation='elu', subsample=(2, 2)))
24 model.add(Conv2D(36, 5, 5, activation='elu', subsample=(2, 2)))
25 model.add(Conv2D(48, 5, 5, activation='elu', subsample=(2, 2)))
~\anaconda3\lib\site-packages\tensorflow\python\keras\layers\convolutional.py in __init__(self, filters, kernel_size, strides, padding, data_format, dilation_rate, groups, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, **kwargs)
651 bias_constraint=None,
652 **kwargs):
--> 653 super(Conv2D, self).__init__(
654 rank=2,
655 filters=filters,
~\anaconda3\lib\site-packages\tensorflow\python\keras\layers\convolutional.py in __init__(self, rank, filters, kernel_size, strides, padding, data_format, dilation_rate, groups, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, trainable, name, conv_op, **kwargs)
132 conv_op=None,
133 **kwargs):
--> 134 super(Conv, self).__init__(
135 trainable=trainable,
136 name=name,
~\anaconda3\lib\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
515 self._self_setattr_tracking = False # pylint: disable=protected-access
516 try:
--> 517 result = method(self, *args, **kwargs)
518 finally:
519 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\anaconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __init__(self, trainable, name, dtype, dynamic, **kwargs)
338 }
339 # Validate optional keyword arguments.
--> 340 generic_utils.validate_kwargs(kwargs, allowed_kwargs)
341
342 # Mutable properties
~\anaconda3\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py in validate_kwargs(kwargs, allowed_kwargs, error_message)
806 for kwarg in kwargs:
807 if kwarg not in allowed_kwargs:
--> 808 raise TypeError(error_message, kwarg)
809
810
TypeError: ('Keyword argument not understood:', 'subsample')
我不知道这是什么意思。
P.S.：我用 `args, unknown = parser.parse_known_args()` 替换了 `args = parser.parse_args()`（特此说明，以防这正是出错的原因）。
【问题讨论】:
标签: python python-3.x tensorflow keras training-data