diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c index 14e878b6b8..dc47c9b542 100644 --- a/libavfilter/dnn/dnn_backend_native.c +++ b/libavfilter/dnn/dnn_backend_native.c @@ -70,64 +70,6 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i return DNN_ERROR; } -static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name) -{ - NativeModel *native_model = (NativeModel *)model; - NativeContext *ctx = &native_model->ctx; - DnnOperand *oprd = NULL; - DNNData input; - - if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { - av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); - return DNN_ERROR; - } - - /* inputs */ - for (int i = 0; i < native_model->operands_num; ++i) { - oprd = &native_model->operands[i]; - if (strcmp(oprd->name, input_name) == 0) { - if (oprd->type != DOT_INPUT) { - av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name); - return DNN_ERROR; - } - break; - } - oprd = NULL; - } - if (!oprd) { - av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name); - return DNN_ERROR; - } - - oprd->dims[1] = frame->height; - oprd->dims[2] = frame->width; - - av_freep(&oprd->data); - oprd->length = calculate_operand_data_length(oprd); - if (oprd->length <= 0) { - av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n"); - return DNN_ERROR; - } - oprd->data = av_malloc(oprd->length); - if (!oprd->data) { - av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n"); - return DNN_ERROR; - } - - input.height = oprd->dims[1]; - input.width = oprd->dims[2]; - input.channels = oprd->dims[3]; - input.data = oprd->data; - input.dt = oprd->data_type; - if (native_model->model->pre_proc != NULL) { - native_model->model->pre_proc(frame, &input, native_model->model->userdata); - } else { - proc_from_frame_to_dnn(frame, &input, ctx); - } - - return DNN_SUCCESS; -} - // Loads model and its parameters that are stored in a binary file with following structure: // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters... // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases @@ -273,7 +215,6 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio return NULL; } - model->set_input = &set_input_native; model->get_input = &get_input_native; model->userdata = userdata; @@ -285,26 +226,66 @@ fail: return NULL; } -DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) +DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame) { NativeModel *native_model = (NativeModel *)model->model; NativeContext *ctx = &native_model->ctx; int32_t layer; - DNNData output; + DNNData input, output; + DnnOperand *oprd = NULL; - if (nb_output != 1) { - // currently, the filter does not need multiple outputs, - // so we just pending the support until we really need it. 
- av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n"); + if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { + av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); return DNN_ERROR; } - if (native_model->layers_num <= 0 || native_model->operands_num <= 0) { - av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n"); + for (int i = 0; i < native_model->operands_num; ++i) { + oprd = &native_model->operands[i]; + if (strcmp(oprd->name, input_name) == 0) { + if (oprd->type != DOT_INPUT) { + av_log(ctx, AV_LOG_ERROR, "Found \"%s\" in model, but it is not input node\n", input_name); + return DNN_ERROR; + } + break; + } + oprd = NULL; + } + if (!oprd) { + av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name); + return DNN_ERROR; + } + + oprd->dims[1] = in_frame->height; + oprd->dims[2] = in_frame->width; + + av_freep(&oprd->data); + oprd->length = calculate_operand_data_length(oprd); + if (oprd->length <= 0) { + av_log(ctx, AV_LOG_ERROR, "The input data length overflow\n"); return DNN_ERROR; } - if (!native_model->operands[0].data) { - av_log(ctx, AV_LOG_ERROR, "Empty model input data\n"); + oprd->data = av_malloc(oprd->length); + if (!oprd->data) { + av_log(ctx, AV_LOG_ERROR, "Failed to malloc memory for input data\n"); + return DNN_ERROR; + } + + input.height = oprd->dims[1]; + input.width = oprd->dims[2]; + input.channels = oprd->dims[3]; + input.data = oprd->data; + input.dt = oprd->data_type; + if (native_model->model->pre_proc != NULL) { + native_model->model->pre_proc(in_frame, &input, native_model->model->userdata); + } else { + proc_from_frame_to_dnn(in_frame, &input, ctx); + } + + if (nb_output != 1) { + // currently, the filter does not need multiple outputs, + // so we just pending the support until we really need it. 
+ av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n"); return DNN_ERROR; } diff --git a/libavfilter/dnn/dnn_backend_native.h b/libavfilter/dnn/dnn_backend_native.h index 553438bd22..2f8d73fcf6 100644 --- a/libavfilter/dnn/dnn_backend_native.h +++ b/libavfilter/dnn/dnn_backend_native.h @@ -128,7 +128,8 @@ typedef struct NativeModel{ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); +DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_native(DNNModel **model); diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index b1bad3f659..0dba1c1adc 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -48,7 +48,6 @@ typedef struct OVModel{ ie_network_t *network; ie_executable_network_t *exe_network; ie_infer_request_t *infer_request; - ie_blob_t *input_blob; } OVModel; #define APPEND_STRING(generated_string, iterate_string) \ @@ -133,49 +132,6 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input return DNN_ERROR; } -static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name) -{ - OVModel *ov_model = (OVModel *)model; - OVContext *ctx = &ov_model->ctx; - IEStatusCode status; - dimensions_t dims; - precision_e precision; - ie_blob_buffer_t blob_buffer; - DNNData input; - - status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob); - if (status != OK) - goto err; - - status |= ie_blob_get_dims(ov_model->input_blob, &dims); - status |= ie_blob_get_precision(ov_model->input_blob, &precision); - if (status != OK) - goto err; - - status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer); - if (status != OK) - goto err; - - input.height = dims.dims[2]; - input.width = dims.dims[3]; - input.channels = dims.dims[1]; - input.data = blob_buffer.buffer; - input.dt = precision_to_datatype(precision); - if (ov_model->model->pre_proc != NULL) { - ov_model->model->pre_proc(frame, &input, ov_model->model->userdata); - } else { - proc_from_frame_to_dnn(frame, &input, ctx); - } - - return DNN_SUCCESS; - -err: - if (ov_model->input_blob) - ie_blob_free(&ov_model->input_blob); - av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n"); - return DNN_ERROR; -} - DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata) { char *all_dev_names = NULL; @@ -234,7 +190,6 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, goto err; model->model = (void *)ov_model; - model->set_input = &set_input_ov; model->get_input = &get_input_ov; model->options = options; model->userdata = userdata; @@ -258,7 +213,8 @@ err: return NULL; } -DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) +DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame) { char *model_output_name = NULL; char *all_output_names = NULL; @@ -269,7 +225,39 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output 
OVContext *ctx = &ov_model->ctx; IEStatusCode status; size_t model_output_count = 0; - DNNData output; + DNNData input, output; + ie_blob_t *input_blob = NULL; + + status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &input_blob); + if (status != OK) { + av_log(ctx, AV_LOG_ERROR, "Failed to get input blob\n"); + return DNN_ERROR; + } + + status |= ie_blob_get_dims(input_blob, &dims); + status |= ie_blob_get_precision(input_blob, &precision); + if (status != OK) { + av_log(ctx, AV_LOG_ERROR, "Failed to get input blob dims/precision\n"); + return DNN_ERROR; + } + + status = ie_blob_get_buffer(input_blob, &blob_buffer); + if (status != OK) { + av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n"); + return DNN_ERROR; + } + + input.height = dims.dims[2]; + input.width = dims.dims[3]; + input.channels = dims.dims[1]; + input.data = blob_buffer.buffer; + input.dt = precision_to_datatype(precision); + if (ov_model->model->pre_proc != NULL) { + ov_model->model->pre_proc(in_frame, &input, ov_model->model->userdata); + } else { + proc_from_frame_to_dnn(in_frame, &input, ctx); + } + ie_blob_free(&input_blob); if (nb_output != 1) { // currently, the filter does not need multiple outputs, @@ -330,6 +318,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output proc_from_dnn_to_frame(out_frame, &output, ctx); } } + ie_blob_free(&output_blob); } return DNN_SUCCESS; @@ -339,8 +328,6 @@ void ff_dnn_free_model_ov(DNNModel **model) { if (*model){ OVModel *ov_model = (OVModel *)(*model)->model; - if (ov_model->input_blob) - ie_blob_free(&ov_model->input_blob); if (ov_model->infer_request) ie_infer_request_free(&ov_model->infer_request); if (ov_model->exe_network) diff --git a/libavfilter/dnn/dnn_backend_openvino.h b/libavfilter/dnn/dnn_backend_openvino.h index efb349cb49..3f8f01da60 100644 --- a/libavfilter/dnn/dnn_backend_openvino.h +++ b/libavfilter/dnn/dnn_backend_openvino.h @@ -31,7 +31,8 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); +DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_ov(DNNModel **model); diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index c2d8c06931..8467f8a459 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -45,8 +45,6 @@ typedef struct TFModel{ TF_Graph *graph; TF_Session *session; TF_Status *status; - TF_Output input; - TF_Tensor *input_tensor; } TFModel; static const AVClass dnn_tensorflow_class = { @@ -152,48 +150,33 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input return DNN_SUCCESS; } -static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name) +static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename) { - TFModel *tf_model = (TFModel *)model; TFContext *ctx = &tf_model->ctx; - DNNData input; + TF_Buffer *graph_def; + TF_ImportGraphDefOptions *graph_opts; TF_SessionOptions *sess_opts; - const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); - - if (get_input_tf(model, &input, input_name) != DNN_SUCCESS) - return DNN_ERROR; - input.height = frame->height; - input.width = frame->width; + const TF_Operation *init_op; - // 
Input operation - tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name); - if (!tf_model->input.oper){ - av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name); + graph_def = read_graph(model_filename); + if (!graph_def){ + av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename); return DNN_ERROR; } - tf_model->input.index = 0; - if (tf_model->input_tensor){ - TF_DeleteTensor(tf_model->input_tensor); - } - tf_model->input_tensor = allocate_input_tensor(&input); - if (!tf_model->input_tensor){ - av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n"); + tf_model->graph = TF_NewGraph(); + tf_model->status = TF_NewStatus(); + graph_opts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status); + TF_DeleteImportGraphDefOptions(graph_opts); + TF_DeleteBuffer(graph_def); + if (TF_GetCode(tf_model->status) != TF_OK){ + TF_DeleteGraph(tf_model->graph); + TF_DeleteStatus(tf_model->status); + av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n"); return DNN_ERROR; } - input.data = (float *)TF_TensorData(tf_model->input_tensor); - - if (tf_model->model->pre_proc != NULL) { - tf_model->model->pre_proc(frame, &input, tf_model->model->userdata); - } else { - proc_from_frame_to_dnn(frame, &input, ctx); - } - - // session - if (tf_model->session){ - TF_CloseSession(tf_model->session, tf_model->status); - TF_DeleteSession(tf_model->session, tf_model->status); - } + init_op = TF_GraphOperationByName(tf_model->graph, "init"); sess_opts = TF_NewSessionOptions(); tf_model->session = TF_NewSession(tf_model->graph, sess_opts, tf_model->status); TF_DeleteSessionOptions(sess_opts); @@ -219,33 +202,6 @@ static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input return DNN_SUCCESS; } -static DNNReturnType load_tf_model(TFModel *tf_model, const char *model_filename) -{ - TFContext *ctx = &tf_model->ctx; - TF_Buffer *graph_def; - TF_ImportGraphDefOptions *graph_opts; - - graph_def = read_graph(model_filename); - if (!graph_def){ - av_log(ctx, AV_LOG_ERROR, "Failed to read model \"%s\" graph\n", model_filename); - return DNN_ERROR; - } - tf_model->graph = TF_NewGraph(); - tf_model->status = TF_NewStatus(); - graph_opts = TF_NewImportGraphDefOptions(); - TF_GraphImportGraphDef(tf_model->graph, graph_def, graph_opts, tf_model->status); - TF_DeleteImportGraphDefOptions(graph_opts); - TF_DeleteBuffer(graph_def); - if (TF_GetCode(tf_model->status) != TF_OK){ - TF_DeleteGraph(tf_model->graph); - TF_DeleteStatus(tf_model->status); - av_log(ctx, AV_LOG_ERROR, "Failed to import serialized graph to model graph\n"); - return DNN_ERROR; - } - - return DNN_SUCCESS; -} - #define NAME_BUFFER_SIZE 256 static DNNReturnType add_conv_layer(TFModel *tf_model, TF_Operation *transpose_op, TF_Operation **cur_op, @@ -626,7 +582,6 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, } model->model = (void *)tf_model; - model->set_input = &set_input_tf; model->get_input = &get_input_tf; model->options = options; model->userdata = userdata; @@ -634,13 +589,40 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, return model; } -DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame) +DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t 
nb_output, AVFrame *out_frame) { TF_Output *tf_outputs; TFModel *tf_model = (TFModel *)model->model; TFContext *ctx = &tf_model->ctx; - DNNData output; + DNNData input, output; TF_Tensor **output_tensors; + TF_Output tf_input; + TF_Tensor *input_tensor; + + if (get_input_tf(tf_model, &input, input_name) != DNN_SUCCESS) + return DNN_ERROR; + input.height = in_frame->height; + input.width = in_frame->width; + + tf_input.oper = TF_GraphOperationByName(tf_model->graph, input_name); + if (!tf_input.oper){ + av_log(ctx, AV_LOG_ERROR, "Could not find \"%s\" in model\n", input_name); + return DNN_ERROR; + } + tf_input.index = 0; + input_tensor = allocate_input_tensor(&input); + if (!input_tensor){ + av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n"); + return DNN_ERROR; + } + input.data = (float *)TF_TensorData(input_tensor); + + if (tf_model->model->pre_proc != NULL) { + tf_model->model->pre_proc(in_frame, &input, tf_model->model->userdata); + } else { + proc_from_frame_to_dnn(in_frame, &input, ctx); + } if (nb_output != 1) { // currently, the filter does not need multiple outputs, @@ -674,7 +656,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output } TF_SessionRun(tf_model->session, NULL, - &tf_model->input, &tf_model->input_tensor, 1, + &tf_input, &input_tensor, 1, tf_outputs, output_tensors, nb_output, NULL, 0, NULL, tf_model->status); if (TF_GetCode(tf_model->status) != TF_OK) { @@ -708,6 +690,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output TF_DeleteTensor(output_tensors[i]); } } + TF_DeleteTensor(input_tensor); av_freep(&output_tensors); av_freep(&tf_outputs); return DNN_SUCCESS; @@ -729,9 +712,6 @@ void ff_dnn_free_model_tf(DNNModel **model) if (tf_model->status){ TF_DeleteStatus(tf_model->status); } - if (tf_model->input_tensor){ - TF_DeleteTensor(tf_model->input_tensor); - } av_freep(&tf_model); av_freep(model); } diff --git a/libavfilter/dnn/dnn_backend_tf.h b/libavfilter/dnn/dnn_backend_tf.h index f379e83d8d..1e00669736 100644 --- a/libavfilter/dnn/dnn_backend_tf.h +++ b/libavfilter/dnn/dnn_backend_tf.h @@ -31,7 +31,8 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata); -DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); +DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame); void ff_dnn_free_model_tf(DNNModel **model); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 6debc50607..0369ee4f71 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -51,9 +51,6 @@ typedef struct DNNModel{ // Gets model input information // Just reuse struct DNNData here, actually the DNNData.data field is not needed. DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name); - // Sets model input. - // Should be called every time before model execution. - DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name); // set the pre process to transfer data from AVFrame to DNNData // the default implementation within DNN is used if it is not provided by the filter int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data); @@ -66,8 +63,9 @@ typedef struct DNNModel{ typedef struct DNNModule{ // Loads model and parameters from given file. Returns NULL if it is not possible. 
DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata); - // Executes model with specified output. Returns DNN_ERROR otherwise. - DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame); + // Executes model with specified input and output. Returns DNN_ERROR otherwise. + DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame, + const char **output_names, uint32_t nb_output, AVFrame *out_frame); // Frees memory allocated for model. void (*free_model)(DNNModel **model); } DNNModule; diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c index a59cd6e941..77dd401263 100644 --- a/libavfilter/vf_derain.c +++ b/libavfilter/vf_derain.c @@ -80,13 +80,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) const char *model_output_name = "y"; AVFrame *out; - dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x"); - if (dnn_result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); - av_frame_free(&in); - return AVERROR(EIO); - } - out = ff_get_video_buffer(outlink, outlink->w, outlink->h); if (!out) { av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n"); @@ -95,7 +88,7 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) } av_frame_copy_props(out, in); - dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out); + dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, "x", in, &model_output_name, 1, out); if (dnn_result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); av_frame_free(&in); diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index d7462bc828..2c8578c9b0 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -236,15 +236,11 @@ static int config_output(AVFilterLink *outlink) AVFrame *out = NULL; AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); - result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname); - if (result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); - return AVERROR(EIO); - } // have a try run in case that the dnn model resize the frame out = ff_get_video_buffer(inlink, inlink->w, inlink->h); - result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, fake_in, + (const char **)&ctx->model_outputname, 1, out); if (result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); return AVERROR(EIO); @@ -293,13 +289,6 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) DNNReturnType dnn_result; AVFrame *out; - dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname); - if (dnn_result != DNN_SUCCESS) { - av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); - av_frame_free(&in); - return AVERROR(EIO); - } - out = ff_get_video_buffer(outlink, outlink->w, outlink->h); if (!out) { av_frame_free(&in); @@ -307,7 +296,8 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) } av_frame_copy_props(out, in); - dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, ctx->model_inputname, in, + (const char **)&ctx->model_outputname, 
1, out); if (dnn_result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); av_frame_free(&in); diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c index 2eda8c3219..72a3137262 100644 --- a/libavfilter/vf_sr.c +++ b/libavfilter/vf_sr.c @@ -114,16 +114,11 @@ static int config_output(AVFilterLink *outlink) AVFrame *out = NULL; const char *model_output_name = "y"; - AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); - result = (ctx->model->set_input)(ctx->model->model, fake_in, "x"); - if (result != DNN_SUCCESS) { - av_log(context, AV_LOG_ERROR, "could not set input for the model\n"); - return AVERROR(EIO); - } - // have a try run in case that the dnn model resize the frame + AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h); out = ff_get_video_buffer(inlink, inlink->w, inlink->h); - result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out); + result = (ctx->dnn_module->execute_model)(ctx->model, "x", fake_in, + (const char **)&model_output_name, 1, out); if (result != DNN_SUCCESS){ av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n"); return AVERROR(EIO); @@ -178,19 +173,13 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) sws_scale(ctx->sws_pre_scale, (const uint8_t **)in->data, in->linesize, 0, in->height, out->data, out->linesize); - dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x"); + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", out, + (const char **)&model_output_name, 1, out); } else { - dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x"); - } - - if (dnn_result != DNN_SUCCESS) { - av_frame_free(&in); - av_frame_free(&out); - av_log(context, AV_LOG_ERROR, "could not set input for the model\n"); - return AVERROR(EIO); + dnn_result = (ctx->dnn_module->execute_model)(ctx->model, "x", in, + (const char **)&model_output_name, 1, out); } - dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out); if (dnn_result != DNN_SUCCESS){ av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n"); av_frame_free(&in);
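
Note (not part of the patch): with the per-backend set_input() callbacks removed, a filter now passes the input tensor name and input AVFrame straight to execute_model(), and the backend does the lookup, buffer allocation, pre_proc and inference in one call. A minimal caller-side sketch of the new pattern, modeled on the vf_derain.c hunk above, follows; the helper name run_model_on_frame is illustrative only, while DRContext, its dnn_module/model fields, the "x"/"y" tensor names and the libavfilter calls are taken from that filter.

    /* Illustrative sketch only -- mirrors the updated call pattern in vf_derain.c. */
    static int run_model_on_frame(AVFilterContext *filter_ctx, DRContext *dr_context,
                                  AVFilterLink *outlink, AVFrame *in)
    {
        const char *model_output_name = "y";   /* output tensor name used by vf_derain */
        DNNReturnType dnn_result;
        AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);

        if (!out) {
            av_frame_free(&in);
            return AVERROR(ENOMEM);
        }
        av_frame_copy_props(out, in);

        /* One call now covers what set_input() + execute_model() used to do:
         * the backend finds the "x" input, fills it from 'in' (running pre_proc
         * if the filter provided one), runs inference, and writes the named
         * output into 'out'. */
        dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, "x", in,
                                                             &model_output_name, 1, out);
        if (dnn_result != DNN_SUCCESS) {
            av_log(filter_ctx, AV_LOG_ERROR, "failed to execute model\n");
            av_frame_free(&in);
            av_frame_free(&out);
            return AVERROR(EIO);
        }

        av_frame_free(&in);
        return ff_filter_frame(outlink, out);
    }

Because the input frame travels with the execute call, the backends no longer need to cache input state between calls, which is why this patch can also drop OVModel.input_blob and TFModel.input/input_tensor and free the per-run blobs/tensors inside execute itself.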