Currently, every filter needs to provide its own code to transfer data from AVFrame* to the model input (DNNData*), and from the model output (DNNData*) back to AVFrame*. Such transfer can instead be implemented once within the DNN module, so each filter can focus on its own business logic. The DNN module also exposes the function pointers pre_proc and post_proc in struct DNNModel, in case a filter has special logic for transferring data between AVFrame* and DNNData*; the default implementation within the DNN module is used if the filter does not set pre_proc/post_proc.
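For illustration only (not part of this patch): a filter that does need its own conversion would fill the new hooks roughly as sketched below. The function my_pre_proc and its 1/255 normalization are hypothetical stand-ins (modeled on the loop removed from vf_derain.c further down); the pre_proc/post_proc fields, their signatures, and the userdata plumbing are the ones introduced by this patch.

    /* Hypothetical filter-side hook; when left NULL, the DNN module falls
     * back to its default conversion (proc_from_frame_to_dnn). */
    static int my_pre_proc(AVFrame *frame_in, DNNData *model_input, void *user_data)
    {
        float *dst = model_input->data;
        for (int i = 0; i < frame_in->height; i++)
            for (int j = 0; j < frame_in->width * 3; j++)
                dst[i * frame_in->width * 3 + j] =
                    frame_in->data[0][i * frame_in->linesize[0] + j] / 255.0f;
        return 0;
    }

    model = (dnn_module->load_model)(model_filename, backend_options, filter_ctx);
    model->pre_proc  = my_pre_proc;   /* optional */
    model->post_proc = my_post_proc;  /* analogous hook for DNNData -> AVFrame, not shown */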
--- a/configure
+++ b/configure
@@ -2628,6 +2628,7 @@ cbs_vp9_select="cbs"
 dct_select="rdft"
 dirac_parse_select="golomb"
 dnn_suggest="libtensorflow libopenvino"
+dnn_deps="swscale"
 error_resilience_select="me_cmp"
 faandct_deps="faan"
 faandct_select="fdctdsp"
@@ -3532,7 +3533,6 @@ derain_filter_select="dnn"
 deshake_filter_select="pixelutils"
 deshake_opencl_filter_deps="opencl"
 dilation_opencl_filter_deps="opencl"
-dnn_processing_filter_deps="swscale"
 dnn_processing_filter_select="dnn"
 drawtext_filter_deps="libfreetype"
 drawtext_filter_suggest="libfontconfig libfribidi"
--- a/libavfilter/dnn/Makefile
+++ b/libavfilter/dnn/Makefile
@@ -1,4 +1,5 @@
 OBJS-$(CONFIG_DNN) += dnn/dnn_interface.o
+OBJS-$(CONFIG_DNN) += dnn/dnn_io_proc.o
 OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native.o
 OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layers.o
 OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layer_avgpool.o
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -27,6 +27,7 @@
 #include "libavutil/avassert.h"
 #include "dnn_backend_native_layer_conv2d.h"
 #include "dnn_backend_native_layers.h"
+#include "dnn_io_proc.h"

 #define OFFSET(x) offsetof(NativeContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
@@ -69,11 +70,12 @@ static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
     return DNN_ERROR;
 }

-static DNNReturnType set_input_native(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name)
 {
     NativeModel *native_model = (NativeModel *)model;
     NativeContext *ctx = &native_model->ctx;
     DnnOperand *oprd = NULL;
+    DNNData input;

     if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
         av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
@@ -97,10 +99,8 @@ static DNNReturnType set_input_native(void *model, DNNData *input, const char *i
         return DNN_ERROR;
     }

-    oprd->dims[0] = 1;
-    oprd->dims[1] = input->height;
-    oprd->dims[2] = input->width;
-    oprd->dims[3] = input->channels;
+    oprd->dims[1] = frame->height;
+    oprd->dims[2] = frame->width;

     av_freep(&oprd->data);
     oprd->length = calculate_operand_data_length(oprd);
@@ -114,7 +114,16 @@ static DNNReturnType set_input_native(void *model, DNNData *input, const char *i
         return DNN_ERROR;
     }

-    input->data = oprd->data;
+    input.height = oprd->dims[1];
+    input.width = oprd->dims[2];
+    input.channels = oprd->dims[3];
+    input.data = oprd->data;
+    input.dt = oprd->data_type;
+    if (native_model->model->pre_proc != NULL) {
+        native_model->model->pre_proc(frame, &input, native_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }

     return DNN_SUCCESS;
 }
@@ -185,6 +194,7 @@ DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
     if (av_opt_set_from_string(&native_model->ctx, model->options, NULL, "=", "&") < 0)
         goto fail;
     model->model = (void *)native_model;
+    native_model->model = model;

 #if !HAVE_PTHREAD_CANCEL
     if (native_model->ctx.options.conv2d_threads > 1){
@@ -275,11 +285,19 @@ fail:
     return NULL;
 }

-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     NativeModel *native_model = (NativeModel *)model->model;
     NativeContext *ctx = &native_model->ctx;
     int32_t layer;
+    DNNData output;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just postpone the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }

     if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
         av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
@@ -317,11 +335,22 @@ DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output
             return DNN_ERROR;
         }

-        outputs[i].data = oprd->data;
-        outputs[i].height = oprd->dims[1];
-        outputs[i].width = oprd->dims[2];
-        outputs[i].channels = oprd->dims[3];
-        outputs[i].dt = oprd->data_type;
+        output.data = oprd->data;
+        output.height = oprd->dims[1];
+        output.width = oprd->dims[2];
+        output.channels = oprd->dims[3];
+        output.dt = oprd->data_type;
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (native_model->model->post_proc != NULL) {
+                native_model->model->post_proc(out_frame, &output, native_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }

     return DNN_SUCCESS;
--- a/libavfilter/dnn/dnn_backend_native.h
+++ b/libavfilter/dnn/dnn_backend_native.h
@@ -119,6 +119,7 @@ typedef struct NativeContext {
 // Represents simple feed-forward convolutional network.
 typedef struct NativeModel{
     NativeContext ctx;
+    DNNModel *model;
     Layer *layers;
     int32_t layers_num;
     DnnOperand *operands;
@@ -127,7 +128,7 @@ typedef struct NativeModel{

 DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata);

-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);

 void ff_dnn_free_model_native(DNNModel **model);
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -24,6 +24,7 @@
  */

 #include "dnn_backend_openvino.h"
+#include "dnn_io_proc.h"
 #include "libavformat/avio.h"
 #include "libavutil/avassert.h"
 #include "libavutil/opt.h"
@@ -42,6 +43,7 @@ typedef struct OVContext {
 typedef struct OVModel{
     OVContext ctx;
+    DNNModel *model;
     ie_core_t *core;
     ie_network_t *network;
     ie_executable_network_t *exe_network;
@@ -131,7 +133,7 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
     return DNN_ERROR;
 }

-static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name)
 {
     OVModel *ov_model = (OVModel *)model;
     OVContext *ctx = &ov_model->ctx;
@@ -139,10 +141,7 @@ static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input
     dimensions_t dims;
     precision_e precision;
     ie_blob_buffer_t blob_buffer;
-
-    status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
-    if (status != OK)
-        goto err;
+    DNNData input;

     status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
     if (status != OK)
@@ -153,23 +152,26 @@ static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input
     if (status != OK)
         goto err;

-    av_assert0(input->channels == dims.dims[1]);
-    av_assert0(input->height == dims.dims[2]);
-    av_assert0(input->width == dims.dims[3]);
-    av_assert0(input->dt == precision_to_datatype(precision));
-
     status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
     if (status != OK)
         goto err;
-    input->data = blob_buffer.buffer;
+
+    input.height = dims.dims[2];
+    input.width = dims.dims[3];
+    input.channels = dims.dims[1];
+    input.data = blob_buffer.buffer;
+    input.dt = precision_to_datatype(precision);
+
+    if (ov_model->model->pre_proc != NULL) {
+        ov_model->model->pre_proc(frame, &input, ov_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }

     return DNN_SUCCESS;
 err:
     if (ov_model->input_blob)
         ie_blob_free(&ov_model->input_blob);
-    if (ov_model->infer_request)
-        ie_infer_request_free(&ov_model->infer_request);
     av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
     return DNN_ERROR;
 }
@@ -184,7 +186,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, 
     ie_config_t config = {NULL, NULL, NULL};
     ie_available_devices_t a_dev;

-    model = av_malloc(sizeof(DNNModel));
+    model = av_mallocz(sizeof(DNNModel));
     if (!model){
         return NULL;
     }
@@ -192,6 +194,7 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, 
     ov_model = av_mallocz(sizeof(OVModel));
     if (!ov_model)
         goto err;
+    ov_model->model = model;
     ov_model->ctx.class = &dnn_openvino_class;
     ctx = &ov_model->ctx;
@@ -226,6 +229,10 @@ DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, 
         goto err;
     }

+    status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
+    if (status != OK)
+        goto err;
+
     model->model = (void *)ov_model;
     model->set_input = &set_input_ov;
     model->get_input = &get_input_ov;
@@ -238,6 +245,8 @@ err:
     if (model)
         av_freep(&model);
     if (ov_model) {
+        if (ov_model->infer_request)
+            ie_infer_request_free(&ov_model->infer_request);
         if (ov_model->exe_network)
             ie_exec_network_free(&ov_model->exe_network);
         if (ov_model->network)
@@ -249,7 +258,7 @@ err:
     return NULL;
 }

-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     char *model_output_name = NULL;
     char *all_output_names = NULL;
@@ -258,8 +267,18 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c
     ie_blob_buffer_t blob_buffer;
     OVModel *ov_model = (OVModel *)model->model;
     OVContext *ctx = &ov_model->ctx;
-    IEStatusCode status = ie_infer_request_infer(ov_model->infer_request);
+    IEStatusCode status;
     size_t model_output_count = 0;
+    DNNData output;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just postpone the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }
+
+    status = ie_infer_request_infer(ov_model->infer_request);
     if (status != OK) {
         av_log(ctx, AV_LOG_ERROR, "Failed to start synchronous model inference\n");
         return DNN_ERROR;
@@ -296,11 +315,21 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c
             return DNN_ERROR;
         }

-        outputs[i].channels = dims.dims[1];
-        outputs[i].height = dims.dims[2];
-        outputs[i].width = dims.dims[3];
-        outputs[i].dt = precision_to_datatype(precision);
-        outputs[i].data = blob_buffer.buffer;
+        output.channels = dims.dims[1];
+        output.height = dims.dims[2];
+        output.width = dims.dims[3];
+        output.dt = precision_to_datatype(precision);
+        output.data = blob_buffer.buffer;
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (ov_model->model->post_proc != NULL) {
+                ov_model->model->post_proc(out_frame, &output, ov_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }

     return DNN_SUCCESS;
--- a/libavfilter/dnn/dnn_backend_openvino.h
+++ b/libavfilter/dnn/dnn_backend_openvino.h
@@ -31,7 +31,7 @@

 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata);

-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);

 void ff_dnn_free_model_ov(DNNModel **model);
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -31,6 +31,7 @@
 #include "libavutil/avassert.h"
 #include "dnn_backend_native_layer_pad.h"
 #include "dnn_backend_native_layer_maximum.h"
+#include "dnn_io_proc.h"

 #include <tensorflow/c/c_api.h>
@@ -40,13 +41,12 @@ typedef struct TFContext {
 typedef struct TFModel{
     TFContext ctx;
+    DNNModel *model;
     TF_Graph *graph;
     TF_Session *session;
     TF_Status *status;
     TF_Output input;
     TF_Tensor *input_tensor;
-    TF_Tensor **output_tensors;
-    uint32_t nb_output;
 } TFModel;

 static const AVClass dnn_tensorflow_class = {
@@ -152,13 +152,19 @@ static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     return DNN_SUCCESS;
 }

-static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name)
 {
     TFModel *tf_model = (TFModel *)model;
     TFContext *ctx = &tf_model->ctx;
+    DNNData input;
     TF_SessionOptions *sess_opts;
     const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");

+    if (get_input_tf(model, &input, input_name) != DNN_SUCCESS)
+        return DNN_ERROR;
+    input.height = frame->height;
+    input.width = frame->width;
+
     // Input operation
     tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
     if (!tf_model->input.oper){
@@ -169,12 +175,18 @@ static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input
     if (tf_model->input_tensor){
         TF_DeleteTensor(tf_model->input_tensor);
     }
-    tf_model->input_tensor = allocate_input_tensor(input);
+    tf_model->input_tensor = allocate_input_tensor(&input);
     if (!tf_model->input_tensor){
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
         return DNN_ERROR;
     }
-    input->data = (float *)TF_TensorData(tf_model->input_tensor);
+    input.data = (float *)TF_TensorData(tf_model->input_tensor);
+
+    if (tf_model->model->pre_proc != NULL) {
+        tf_model->model->pre_proc(frame, &input, tf_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }

     // session
     if (tf_model->session){
@@ -591,7 +603,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, 
     DNNModel *model = NULL;
     TFModel *tf_model = NULL;

-    model = av_malloc(sizeof(DNNModel));
+    model = av_mallocz(sizeof(DNNModel));
     if (!model){
         return NULL;
     }
@@ -602,6 +614,7 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, 
         return NULL;
     }
     tf_model->ctx.class = &dnn_tensorflow_class;
+    tf_model->model = model;

     if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
         if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
@@ -621,11 +634,20 @@ DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, 
     return model;
 }

-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     TF_Output *tf_outputs;
     TFModel *tf_model = (TFModel *)model->model;
     TFContext *ctx = &tf_model->ctx;
+    DNNData output;
+    TF_Tensor **output_tensors;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just postpone the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }

     tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
     if (tf_outputs == NULL) {
@@ -633,18 +655,8 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
         return DNN_ERROR;
     }

-    if (tf_model->output_tensors) {
-        for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-            if (tf_model->output_tensors[i]) {
-                TF_DeleteTensor(tf_model->output_tensors[i]);
-                tf_model->output_tensors[i] = NULL;
-            }
-        }
-    }
-    av_freep(&tf_model->output_tensors);
-    tf_model->nb_output = nb_output;
-    tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors));
-    if (!tf_model->output_tensors) {
+    output_tensors = av_mallocz_array(nb_output, sizeof(*output_tensors));
+    if (!output_tensors) {
         av_freep(&tf_outputs);
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output tensor\n"); \
         return DNN_ERROR;
@@ -654,6 +666,7 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
         tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
         if (!tf_outputs[i].oper) {
             av_freep(&tf_outputs);
+            av_freep(&output_tensors);
             av_log(ctx, AV_LOG_ERROR, "Could not find output \"%s\" in model\n", output_names[i]); \
             return DNN_ERROR;
         }
@@ -662,22 +675,40 @@ DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
     TF_SessionRun(tf_model->session, NULL,
                   &tf_model->input, &tf_model->input_tensor, 1,
-                  tf_outputs, tf_model->output_tensors, nb_output,
+                  tf_outputs, output_tensors, nb_output,
                   NULL, 0, NULL, tf_model->status);
     if (TF_GetCode(tf_model->status) != TF_OK) {
         av_freep(&tf_outputs);
+        av_freep(&output_tensors);
         av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
         return DNN_ERROR;
     }

     for (uint32_t i = 0; i < nb_output; ++i) {
-        outputs[i].height = TF_Dim(tf_model->output_tensors[i], 1);
-        outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
-        outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
-        outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
-        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
+        output.height = TF_Dim(output_tensors[i], 1);
+        output.width = TF_Dim(output_tensors[i], 2);
+        output.channels = TF_Dim(output_tensors[i], 3);
+        output.data = TF_TensorData(output_tensors[i]);
+        output.dt = TF_TensorType(output_tensors[i]);
+
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (tf_model->model->post_proc != NULL) {
+                tf_model->model->post_proc(out_frame, &output, tf_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }

+    for (uint32_t i = 0; i < nb_output; ++i) {
+        if (output_tensors[i]) {
+            TF_DeleteTensor(output_tensors[i]);
+        }
+    }
+    av_freep(&output_tensors);
     av_freep(&tf_outputs);
     return DNN_SUCCESS;
 }
@@ -701,15 +732,6 @@ void ff_dnn_free_model_tf(DNNModel **model)
         if (tf_model->input_tensor){
             TF_DeleteTensor(tf_model->input_tensor);
         }
-        if (tf_model->output_tensors) {
-            for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-                if (tf_model->output_tensors[i]) {
-                    TF_DeleteTensor(tf_model->output_tensors[i]);
-                    tf_model->output_tensors[i] = NULL;
-                }
-            }
-        }
-        av_freep(&tf_model->output_tensors);
         av_freep(&tf_model);
         av_freep(model);
     }
--- a/libavfilter/dnn/dnn_backend_tf.h
+++ b/libavfilter/dnn/dnn_backend_tf.h
@@ -31,7 +31,7 @@

 DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata);

-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);

 void ff_dnn_free_model_tf(DNNModel **model);
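Taken together, the three backends above now share the same output-side pattern; the condensed paraphrase below is not additional API, just a summary of the hunks above, with the names bound as in each backend's execute function:

    /* The first call can be a "try run": when the model's output size differs
     * from out_frame, the backend only records the new size (used by the
     * filters' config_output below to set outlink->w/h); otherwise it converts
     * DNNData into the frame via the filter's post_proc or the module default. */
    if (out_frame->width != output.width || out_frame->height != output.height) {
        out_frame->width = output.width;
        out_frame->height = output.height;
    } else if (model->post_proc != NULL) {
        model->post_proc(out_frame, &output, model->userdata);  /* filter's own */
    } else {
        proc_from_dnn_to_frame(out_frame, &output, ctx);        /* module default */
    }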
--- /dev/null
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2020
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "dnn_io_proc.h"
+#include "libavutil/imgutils.h"
+#include "libswscale/swscale.h"
+
+DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
+{
+    struct SwsContext *sws_ctx;
+    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    if (output->dt != DNN_FLOAT) {
+        av_log(log_ctx, AV_LOG_ERROR, "do not support data type other than DNN_FLOAT\n");
+        return DNN_ERROR;
+    }
+
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        sws_ctx = sws_getContext(frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
+                  (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height,
+                  (uint8_t * const*)frame->data, frame->linesize);
+        sws_freeContext(sws_ctx);
+        return DNN_SUCCESS;
+    case AV_PIX_FMT_GRAYF32:
+        av_image_copy_plane(frame->data[0], frame->linesize[0],
+                            output->data, bytewidth,
+                            bytewidth, frame->height);
+        return DNN_SUCCESS;
+    case AV_PIX_FMT_YUV420P:
+    case AV_PIX_FMT_YUV422P:
+    case AV_PIX_FMT_YUV444P:
+    case AV_PIX_FMT_YUV410P:
+    case AV_PIX_FMT_YUV411P:
+    case AV_PIX_FMT_GRAY8:
+        sws_ctx = sws_getContext(frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
+                  (const int[4]){frame->width * sizeof(float), 0, 0, 0}, 0, frame->height,
+                  (uint8_t * const*)frame->data, frame->linesize);
+        sws_freeContext(sws_ctx);
+        return DNN_SUCCESS;
+    default:
+        av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
+        return DNN_ERROR;
+    }
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
+{
+    struct SwsContext *sws_ctx;
+    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    if (input->dt != DNN_FLOAT) {
+        av_log(log_ctx, AV_LOG_ERROR, "do not support data type other than DNN_FLOAT\n");
+        return DNN_ERROR;
+    }
+
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        sws_ctx = sws_getContext(frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t **)frame->data,
+                  frame->linesize, 0, frame->height,
+                  (uint8_t * const*)(&input->data),
+                  (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
+        sws_freeContext(sws_ctx);
+        break;
+    case AV_PIX_FMT_GRAYF32:
+        av_image_copy_plane(input->data, bytewidth,
+                            frame->data[0], frame->linesize[0],
+                            bytewidth, frame->height);
+        break;
+    case AV_PIX_FMT_YUV420P:
+    case AV_PIX_FMT_YUV422P:
+    case AV_PIX_FMT_YUV444P:
+    case AV_PIX_FMT_YUV410P:
+    case AV_PIX_FMT_YUV411P:
+    case AV_PIX_FMT_GRAY8:
+        sws_ctx = sws_getContext(frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t **)frame->data,
+                  frame->linesize, 0, frame->height,
+                  (uint8_t * const*)(&input->data),
+                  (const int [4]){frame->width * sizeof(float), 0, 0, 0});
+        sws_freeContext(sws_ctx);
+        break;
+    default:
+        av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
+        return DNN_ERROR;
+    }
+
+    return DNN_SUCCESS;
+}
--- /dev/null
+++ b/libavfilter/dnn/dnn_io_proc.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN input&output process between AVFrame and DNNData.
+ */
+
+#ifndef AVFILTER_DNN_DNN_IO_PROC_H
+#define AVFILTER_DNN_DNN_IO_PROC_H
+
+#include "../dnn_interface.h"
+#include "libavutil/frame.h"
+
+DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx);
+DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
+
+#endif
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -27,6 +27,7 @@
 #define AVFILTER_DNN_INTERFACE_H

 #include <stdint.h>
+#include "libavutil/frame.h"

 typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
@@ -50,17 +51,23 @@ typedef struct DNNModel{
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
-    // Sets model input and output.
-    // Should be called at least once before model execution.
-    DNNReturnType (*set_input)(void *model, DNNData *input, const char *input_name);
+    // Sets model input.
+    // Should be called every time before model execution.
+    DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name);
+    // set the pre process to transfer data from AVFrame to DNNData
+    // the default implementation within DNN is used if it is not provided by the filter
+    int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);
+    // set the post process to transfer data from DNNData to AVFrame
+    // the default implementation within DNN is used if it is not provided by the filter
+    int (*post_proc)(AVFrame *frame_out, DNNData *model_output, void *user_data);
 } DNNModel;

 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
     DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata);
-    // Executes model with specified input and output. Returns DNN_ERROR otherwise.
-    DNNReturnType (*execute_model)(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+    // Executes model with specified output. Returns DNN_ERROR otherwise.
+    DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
     // Frees memory allocated for model.
     void (*free_model)(DNNModel **model);
 } DNNModule;
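Per the updated comments, set_input() is now a per-frame call rather than a one-off configuration step. The per-frame sequence a filter follows is roughly the sketch below (error handling elided; names taken from this patch, as the filter hunks further down implement it):

    /* per decoded frame: */
    (model->set_input)(model->model, in_frame, model_inputname);           /* runs pre_proc or default */
    (dnn_module->execute_model)(model, &model_outputname, 1, out_frame);   /* runs post_proc or default */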
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -39,11 +39,8 @@ typedef struct DRContext {
     DNNBackendType backend_type;
     DNNModule *dnn_module;
     DNNModel *model;
-    DNNData input;
-    DNNData output;
 } DRContext;

-#define CLIP(x, min, max) (x < min ? min : (x > max ? max : x))
 #define OFFSET(x) offsetof(DRContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
@@ -74,25 +71,6 @@ static int query_formats(AVFilterContext *ctx)
     return ff_set_common_formats(ctx, formats);
 }

-static int config_inputs(AVFilterLink *inlink)
-{
-    AVFilterContext *ctx = inlink->dst;
-    DRContext *dr_context = ctx->priv;
-    DNNReturnType result;
-
-    dr_context->input.width = inlink->w;
-    dr_context->input.height = inlink->h;
-    dr_context->input.channels = 3;
-
-    result = (dr_context->model->set_input)(dr_context->model->model, &dr_context->input, "x");
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
-        return AVERROR(EIO);
-    }
-
-    return 0;
-}
-
 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
     AVFilterContext *ctx = inlink->dst;
@@ -100,43 +78,30 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     DRContext *dr_context = ctx->priv;
     DNNReturnType dnn_result;
     const char *model_output_name = "y";
-    AVFrame *out;
+    AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
+
+    dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x");
+    if (dnn_result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
+        av_frame_free(&in);
+        return AVERROR(EIO);
+    }

-    out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n");
         av_frame_free(&in);
         return AVERROR(ENOMEM);
     }
     av_frame_copy_props(out, in);

-    for (int i = 0; i < in->height; i++){
-        for(int j = 0; j < in->width * 3; j++){
-            int k = i * in->linesize[0] + j;
-            int t = i * in->width * 3 + j;
-            ((float *)dr_context->input.data)[t] = in->data[0][k] / 255.0;
-        }
-    }
-
-    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &dr_context->output, &model_output_name, 1);
+    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+        av_frame_free(&in);
         return AVERROR(EIO);
     }

-    out->height = dr_context->output.height;
-    out->width = dr_context->output.width;
-    outlink->h = dr_context->output.height;
-    outlink->w = dr_context->output.width;
-
-    for (int i = 0; i < out->height; i++){
-        for(int j = 0; j < out->width * 3; j++){
-            int k = i * out->linesize[0] + j;
-            int t = i * out->width * 3 + j;
-            out->data[0][k] = CLIP((int)((((float *)dr_context->output.data)[t]) * 255), 0, 255);
-        }
-    }
-
     av_frame_free(&in);

     return ff_filter_frame(outlink, out);
@@ -146,7 +111,6 @@ static av_cold int init(AVFilterContext *ctx)
 {
     DRContext *dr_context = ctx->priv;

-    dr_context->input.dt = DNN_FLOAT;
     dr_context->dnn_module = ff_get_dnn_module(dr_context->backend_type);
     if (!dr_context->dnn_module) {
         av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
@@ -184,7 +148,6 @@ static const AVFilterPad derain_inputs[] = {
     {
         .name         = "default",
         .type         = AVMEDIA_TYPE_VIDEO,
-        .config_props = config_inputs,
         .filter_frame = filter_frame,
     },
     { NULL }
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -46,12 +46,6 @@ typedef struct DnnProcessingContext {
     DNNModule *dnn_module;
     DNNModel *model;

-    // input & output of the model at execution time
-    DNNData input;
-    DNNData output;
-
-    struct SwsContext *sws_gray8_to_grayf32;
-    struct SwsContext *sws_grayf32_to_gray8;
     struct SwsContext *sws_uv_scale;
     int sws_uv_height;
 } DnnProcessingContext;
@@ -103,7 +97,7 @@ static av_cold int init(AVFilterContext *context)
         return AVERROR(EINVAL);
     }

-    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, NULL);
+    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, ctx);
     if (!ctx->model) {
         av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
         return AVERROR(EINVAL);
@@ -148,6 +142,10 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
                    model_input->width, inlink->w);
         return AVERROR(EIO);
     }
+    if (model_input->dt != DNN_FLOAT) {
+        av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32.\n");
+        return AVERROR(EIO);
+    }

     switch (fmt) {
     case AV_PIX_FMT_RGB24:
@@ -156,20 +154,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
-        if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
-            return AVERROR(EIO);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-        if (model_input->channels != 1) {
-            LOG_FORMAT_CHANNEL_MISMATCH();
-            return AVERROR(EIO);
-        }
-        if (model_input->dt != DNN_UINT8) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
-            return AVERROR(EIO);
-        }
         return 0;
     case AV_PIX_FMT_GRAYF32:
     case AV_PIX_FMT_YUV420P:
@@ -181,10 +165,6 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
-        if (model_input->dt != DNN_FLOAT) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
-            return AVERROR(EIO);
-        }
         return 0;
     default:
         av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
@@ -213,74 +193,24 @@ static int config_input(AVFilterLink *inlink)
         return check;
     }

-    ctx->input.width = inlink->w;
-    ctx->input.height = inlink->h;
-    ctx->input.channels = model_input.channels;
-    ctx->input.dt = model_input.dt;
-
-    result = (ctx->model->set_input)(ctx->model->model,
-                                     &ctx->input, ctx->model_inputname);
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
-        return AVERROR(EIO);
-    }
-
     return 0;
 }

-static int prepare_sws_context(AVFilterLink *outlink)
+static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
+{
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
+    av_assert0(desc);
+    return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
+}
+
+static int prepare_uv_scale(AVFilterLink *outlink)
 {
     AVFilterContext *context = outlink->src;
     DnnProcessingContext *ctx = context->priv;
     AVFilterLink *inlink = context->inputs[0];
     enum AVPixelFormat fmt = inlink->format;
-    DNNDataType input_dt = ctx->input.dt;
-    DNNDataType output_dt = ctx->output.dt;

-    switch (fmt) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (input_dt == DNN_FLOAT) {
-            ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3,
-                                                       inlink->h,
-                                                       AV_PIX_FMT_GRAY8,
-                                                       inlink->w * 3,
-                                                       inlink->h,
-                                                       AV_PIX_FMT_GRAYF32,
-                                                       0, NULL, NULL, NULL);
-        }
-        if (output_dt == DNN_FLOAT) {
-            ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3,
-                                                       outlink->h,
-                                                       AV_PIX_FMT_GRAYF32,
-                                                       outlink->w * 3,
-                                                       outlink->h,
-                                                       AV_PIX_FMT_GRAY8,
-                                                       0, NULL, NULL, NULL);
-        }
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        av_assert0(input_dt == DNN_FLOAT);
-        av_assert0(output_dt == DNN_FLOAT);
-        ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w,
-                                                   inlink->h,
-                                                   AV_PIX_FMT_GRAY8,
-                                                   inlink->w,
-                                                   inlink->h,
-                                                   AV_PIX_FMT_GRAYF32,
-                                                   0, NULL, NULL, NULL);
-        ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w,
-                                                   outlink->h,
-                                                   AV_PIX_FMT_GRAYF32,
-                                                   outlink->w,
-                                                   outlink->h,
-                                                   AV_PIX_FMT_GRAY8,
-                                                   0, NULL, NULL, NULL);
+    if (isPlanarYUV(fmt)) {
         if (inlink->w != outlink->w || inlink->h != outlink->h) {
             const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt);
             int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
@@ -292,10 +222,6 @@ static int prepare_sws_context(AVFilterLink *outlink)
                                                SWS_BICUBIC, NULL, NULL, NULL);
             ctx->sws_uv_height = sws_src_h;
         }
-        return 0;
-    default:
-        //do nothing
-        break;
     }

     return 0;
@@ -306,120 +232,34 @@ static int config_output(AVFilterLink *outlink)
     AVFilterContext *context = outlink->src;
     DnnProcessingContext *ctx = context->priv;
     DNNReturnType result;
+    AVFilterLink *inlink = context->inputs[0];
+    AVFrame *out = NULL;

-    // have a try run in case that the dnn model resize the frame
-    result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1);
-    if (result != DNN_SUCCESS){
-        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
+    if (result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
         return AVERROR(EIO);
     }

-    outlink->w = ctx->output.width;
-    outlink->h = ctx->output.height;
-
-    prepare_sws_context(outlink);
-
-    return 0;
-}
-
-static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame)
-{
-    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
-    DNNData *dnn_input = &ctx->input;
-
-    switch (frame->format) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (dnn_input->dt == DNN_FLOAT) {
-            sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
-                      0, frame->height, (uint8_t * const*)(&dnn_input->data),
-                      (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
-        } else {
-            av_assert0(dnn_input->dt == DNN_UINT8);
-            av_image_copy_plane(dnn_input->data, bytewidth,
-                                frame->data[0], frame->linesize[0],
-                                bytewidth, frame->height);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_GRAYF32:
-        av_image_copy_plane(dnn_input->data, bytewidth,
-                            frame->data[0], frame->linesize[0],
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
-                  0, frame->height, (uint8_t * const*)(&dnn_input->data),
-                  (const int [4]){frame->width * sizeof(float), 0, 0, 0});
-        return 0;
-    default:
+    // have a try run in case that the dnn model resize the frame
+    out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    if (result != DNN_SUCCESS){
+        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         return AVERROR(EIO);
     }

-    return 0;
-}
+    outlink->w = out->width;
+    outlink->h = out->height;

-static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame)
-{
-    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
-    DNNData *dnn_output = &ctx->output;
-
-    switch (frame->format) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (dnn_output->dt == DNN_FLOAT) {
-            sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
-                      (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0},
-                      0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
-        } else {
-            av_assert0(dnn_output->dt == DNN_UINT8);
-            av_image_copy_plane(frame->data[0], frame->linesize[0],
-                                dnn_output->data, bytewidth,
-                                bytewidth, frame->height);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-        // it is possible that data type of dnn output is float32,
-        // need to add support for such case when needed.
-        av_assert0(dnn_output->dt == DNN_UINT8);
-        av_image_copy_plane(frame->data[0], frame->linesize[0],
-                            dnn_output->data, bytewidth,
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_GRAYF32:
-        av_assert0(dnn_output->dt == DNN_FLOAT);
-        av_image_copy_plane(frame->data[0], frame->linesize[0],
-                            dnn_output->data, bytewidth,
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
-                  (const int[4]){frame->width * sizeof(float), 0, 0, 0},
-                  0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
-        return 0;
-    default:
-        return AVERROR(EIO);
-    }
+    av_frame_free(&fake_in);
+    av_frame_free(&out);
+    prepare_uv_scale(outlink);

     return 0;
 }

-static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
-{
-    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
-    av_assert0(desc);
-    return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
-}
-
 static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in)
 {
     const AVPixFmtDescriptor *desc;
@@ -453,11 +293,9 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) | |||||
DNNReturnType dnn_result; | DNNReturnType dnn_result; | ||||
AVFrame *out; | AVFrame *out; | ||||
copy_from_frame_to_dnn(ctx, in); | |||||
dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1); | |||||
if (dnn_result != DNN_SUCCESS){ | |||||
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); | |||||
dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname); | |||||
if (dnn_result != DNN_SUCCESS) { | |||||
av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n"); | |||||
av_frame_free(&in); | av_frame_free(&in); | ||||
return AVERROR(EIO); | return AVERROR(EIO); | ||||
} | } | ||||
@@ -467,9 +305,15 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) | |||||
av_frame_free(&in); | av_frame_free(&in); | ||||
return AVERROR(ENOMEM); | return AVERROR(ENOMEM); | ||||
} | } | ||||
av_frame_copy_props(out, in); | av_frame_copy_props(out, in); | ||||
copy_from_dnn_to_frame(ctx, out); | |||||
dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out); | |||||
if (dnn_result != DNN_SUCCESS){ | |||||
av_log(ctx, AV_LOG_ERROR, "failed to execute model\n"); | |||||
av_frame_free(&in); | |||||
av_frame_free(&out); | |||||
return AVERROR(EIO); | |||||
} | |||||
if (isPlanarYUV(in->format)) | if (isPlanarYUV(in->format)) | ||||
copy_uv_planes(ctx, out, in); | copy_uv_planes(ctx, out, in); | ||||
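With the I/O processing moved into the DNN module, the per-frame logic of vf_dnn_processing reduces to: set_input() with the input AVFrame (the module runs the pre-proc and fills the backend's input tensor), execute_model() with the output AVFrame (the module runs the post-proc on the output tensor), plus the filter-specific chroma copy. A condensed sketch of the resulting flow, under the signatures shown in this patch, with error logging trimmed; this is not the verbatim patch code:

    static int filter_frame_sketch(AVFilterLink *inlink, AVFrame *in)
    {
        AVFilterContext *filter_ctx = inlink->dst;
        DnnProcessingContext *ctx = filter_ctx->priv;
        AVFilterLink *outlink = filter_ctx->outputs[0];
        AVFrame *out;

        // pre-proc: AVFrame -> backend input tensor, done by the DNN module
        if ((ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname) != DNN_SUCCESS) {
            av_frame_free(&in);
            return AVERROR(EIO);
        }
        out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
        if (!out) {
            av_frame_free(&in);
            return AVERROR(ENOMEM);
        }
        av_frame_copy_props(out, in);
        // inference + post-proc: output tensor -> AVFrame, done by the DNN module
        if ((ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out) != DNN_SUCCESS) {
            av_frame_free(&in);
            av_frame_free(&out);
            return AVERROR(EIO);
        }
        if (isPlanarYUV(in->format))
            copy_uv_planes(ctx, out, in); // the model only processed the Y plane
        av_frame_free(&in);
        return ff_filter_frame(outlink, out);
    }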
@@ -482,8 +326,6 @@ static av_cold void uninit(AVFilterContext *ctx)
{
    DnnProcessingContext *context = ctx->priv;

    sws_freeContext(context->sws_gray8_to_grayf32);
    sws_freeContext(context->sws_grayf32_to_gray8);
    sws_freeContext(context->sws_uv_scale);

    if (context->dnn_module)
@@ -41,11 +41,10 @@ typedef struct SRContext {
    DNNBackendType backend_type;
    DNNModule *dnn_module;
    DNNModel *model;
    DNNData input;
    DNNData output;
    int scale_factor;
    struct SwsContext *sws_contexts[3];
    int sws_slice_h, sws_input_linesize, sws_output_linesize;
    struct SwsContext *sws_uv_scale;
    int sws_uv_height;
    struct SwsContext *sws_pre_scale;
} SRContext;

#define OFFSET(x) offsetof(SRContext, x)
@@ -87,11 +86,6 @@ static av_cold int init(AVFilterContext *context)
        return AVERROR(EIO);
    }

    sr_context->input.dt = DNN_FLOAT;
    sr_context->sws_contexts[0] = NULL;
    sr_context->sws_contexts[1] = NULL;
    sr_context->sws_contexts[2] = NULL;

    return 0;
}

@@ -111,95 +105,63 @@ static int query_formats(AVFilterContext *context)
    return ff_set_common_formats(context, formats_list);
}
static int config_props(AVFilterLink *inlink)
static int config_output(AVFilterLink *outlink)
{
    AVFilterContext *context = inlink->dst;
    SRContext *sr_context = context->priv;
    AVFilterLink *outlink = context->outputs[0];
    AVFilterContext *context = outlink->src;
    SRContext *ctx = context->priv;
    DNNReturnType result;
    int sws_src_h, sws_src_w, sws_dst_h, sws_dst_w;
    AVFilterLink *inlink = context->inputs[0];
    AVFrame *out = NULL;
    const char *model_output_name = "y";

    sr_context->input.width = inlink->w * sr_context->scale_factor;
    sr_context->input.height = inlink->h * sr_context->scale_factor;
    sr_context->input.channels = 1;

    result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x");
    if (result != DNN_SUCCESS){
        av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
    result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
    if (result != DNN_SUCCESS) {
        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
        return AVERROR(EIO);
    }

    result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
    // do a trial run, in case the dnn model resizes the frame
    out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
    if (result != DNN_SUCCESS){
        av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
        return AVERROR(EIO);
    }

    if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
        sr_context->input.width = inlink->w;
        sr_context->input.height = inlink->h;
        result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x");
        if (result != DNN_SUCCESS){
            av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
            return AVERROR(EIO);
        }
        result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
        if (result != DNN_SUCCESS){
            av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
            return AVERROR(EIO);
        }
        sr_context->scale_factor = 0;
    }
    outlink->h = sr_context->output.height;
    outlink->w = sr_context->output.width;
    sr_context->sws_contexts[1] = sws_getContext(sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAY8,
                                                 sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAYF32,
                                                 0, NULL, NULL, NULL);
    sr_context->sws_input_linesize = sr_context->input.width << 2;
    sr_context->sws_contexts[2] = sws_getContext(sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAYF32,
                                                 sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAY8,
                                                 0, NULL, NULL, NULL);
    sr_context->sws_output_linesize = sr_context->output.width << 2;
    if (!sr_context->sws_contexts[1] || !sr_context->sws_contexts[2]){
        av_log(context, AV_LOG_ERROR, "could not create SwsContext for conversions\n");
        return AVERROR(ENOMEM);
    }
    if (sr_context->scale_factor){
        sr_context->sws_contexts[0] = sws_getContext(inlink->w, inlink->h, inlink->format,
                                                     outlink->w, outlink->h, outlink->format,
                                                     SWS_BICUBIC, NULL, NULL, NULL);
        if (!sr_context->sws_contexts[0]){
            av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n");
            return AVERROR(ENOMEM);
        }
        sr_context->sws_slice_h = inlink->h;
    } else {
    if (fake_in->width != out->width || fake_in->height != out->height) {
        // espcn-style: the model itself scales the frame
        outlink->w = out->width;
        outlink->h = out->height;
        if (inlink->format != AV_PIX_FMT_GRAY8){
            const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
            sws_src_h = AV_CEIL_RSHIFT(sr_context->input.height, desc->log2_chroma_h);
            sws_src_w = AV_CEIL_RSHIFT(sr_context->input.width, desc->log2_chroma_w);
            sws_dst_h = AV_CEIL_RSHIFT(sr_context->output.height, desc->log2_chroma_h);
            sws_dst_w = AV_CEIL_RSHIFT(sr_context->output.width, desc->log2_chroma_w);
            sr_context->sws_contexts[0] = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8,
                                                         sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8,
                                                         SWS_BICUBIC, NULL, NULL, NULL);
            if (!sr_context->sws_contexts[0]){
                av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n");
                return AVERROR(ENOMEM);
            }
            sr_context->sws_slice_h = sws_src_h;
            int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
            int sws_src_w = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w);
            int sws_dst_h = AV_CEIL_RSHIFT(outlink->h, desc->log2_chroma_h);
            int sws_dst_w = AV_CEIL_RSHIFT(outlink->w, desc->log2_chroma_w);
            ctx->sws_uv_scale = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8,
                                               sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8,
                                               SWS_BICUBIC, NULL, NULL, NULL);
            ctx->sws_uv_height = sws_src_h;
        }
    } else {
        // srcnn-style: the frame must be upscaled before inference
        outlink->w = out->width * ctx->scale_factor;
        outlink->h = out->height * ctx->scale_factor;
        ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format,
                                            outlink->w, outlink->h, outlink->format,
                                            SWS_BICUBIC, NULL, NULL, NULL);
    }

    av_frame_free(&fake_in);
    av_frame_free(&out);
    return 0;
}
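The key change in config_output() is that the filter no longer guesses the output geometry from scale_factor alone: it cannot know whether the loaded model scales the frame (espcn-style) or preserves its size (srcnn-style) until the model has run once. Hence the trial run on throwaway frames. Its generic shape looks like this; the helper name and the error mapping are illustrative, not from the patch:

    // Run the model once on dummy frames to learn the output geometry.
    static int probe_model_geometry(SRContext *ctx, AVFilterLink *inlink,
                                    int *out_w, int *out_h)
    {
        AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
        AVFrame *out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
        DNNReturnType r = DNN_ERROR;

        if (fake_in && out) {
            r = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
            if (r == DNN_SUCCESS)
                r = (ctx->dnn_module->execute_model)(ctx->model,
                                                     (const char *[]){"y"}, 1, out);
        }
        *out_w = out ? out->width : 0;  // differs from inlink->w for espcn-style models
        *out_h = out ? out->height : 0;
        av_frame_free(&fake_in);
        av_frame_free(&out);
        return r == DNN_SUCCESS ? 0 : AVERROR(EIO);
    }

Whichever branch the probe selects, exactly one of sws_pre_scale and sws_uv_scale ends up non-NULL, and that is what filter_frame() below dispatches on.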
static int filter_frame(AVFilterLink *inlink, AVFrame *in)
{
    AVFilterContext *context = inlink->dst;
    SRContext *sr_context = context->priv;
    SRContext *ctx = context->priv;
    AVFilterLink *outlink = context->outputs[0];
    AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
    DNNReturnType dnn_result;
@@ -211,45 +173,44 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in)
        return AVERROR(ENOMEM);
    }
    av_frame_copy_props(out, in);
    out->height = sr_context->output.height;
    out->width = sr_context->output.width;
    if (sr_context->scale_factor){
        sws_scale(sr_context->sws_contexts[0], (const uint8_t **)in->data, in->linesize,
                  0, sr_context->sws_slice_h, out->data, out->linesize);

        sws_scale(sr_context->sws_contexts[1], (const uint8_t **)out->data, out->linesize,
                  0, out->height, (uint8_t * const*)(&sr_context->input.data),
                  (const int [4]){sr_context->sws_input_linesize, 0, 0, 0});
    if (ctx->sws_pre_scale) {
        sws_scale(ctx->sws_pre_scale,
                  (const uint8_t **)in->data, in->linesize, 0, in->height,
                  out->data, out->linesize);
        dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x");
    } else {
        if (sr_context->sws_contexts[0]){
            sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 1), in->linesize + 1,
                      0, sr_context->sws_slice_h, out->data + 1, out->linesize + 1);
            sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 2), in->linesize + 2,
                      0, sr_context->sws_slice_h, out->data + 2, out->linesize + 2);
        }
        dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x");
    }

        sws_scale(sr_context->sws_contexts[1], (const uint8_t **)in->data, in->linesize,
                  0, in->height, (uint8_t * const*)(&sr_context->input.data),
                  (const int [4]){sr_context->sws_input_linesize, 0, 0, 0});
    if (dnn_result != DNN_SUCCESS) {
        av_frame_free(&in);
        av_frame_free(&out);
        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
        return AVERROR(EIO);
    }

    av_frame_free(&in);
    dnn_result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
    if (dnn_result != DNN_SUCCESS){
        av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
        av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
        av_frame_free(&in);
        av_frame_free(&out);
        return AVERROR(EIO);
    }

    sws_scale(sr_context->sws_contexts[2], (const uint8_t *[4]){(const uint8_t *)sr_context->output.data, 0, 0, 0},
              (const int[4]){sr_context->sws_output_linesize, 0, 0, 0},
              0, out->height, (uint8_t * const*)out->data, out->linesize);
    if (ctx->sws_uv_scale) {
        sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1,
                  0, ctx->sws_uv_height, out->data + 1, out->linesize + 1);
        sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 2), in->linesize + 2,
                  0, ctx->sws_uv_height, out->data + 2, out->linesize + 2);
    }
    av_frame_free(&in);

    return ff_filter_frame(outlink, out);
}
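The chroma handling deserves a note: for espcn-style models only the Y plane goes through the network, so sws_uv_scale resamples U and V bicubically, treating each subsampled chroma plane as a stand-alone GRAY8 image. A self-contained equivalent of one such plane scale follows; the helper name is mine:

    #include <stdint.h>
    #include "libavutil/error.h"
    #include "libswscale/swscale.h"

    // Bicubically resample one 8-bit plane, e.g. a U or V plane of a
    // planar YUV frame, between arbitrary dimensions.
    static int scale_gray8_plane(const uint8_t *src, int src_linesize, int src_w, int src_h,
                                 uint8_t *dst, int dst_linesize, int dst_w, int dst_h)
    {
        struct SwsContext *sws = sws_getContext(src_w, src_h, AV_PIX_FMT_GRAY8,
                                                dst_w, dst_h, AV_PIX_FMT_GRAY8,
                                                SWS_BICUBIC, NULL, NULL, NULL);
        if (!sws)
            return AVERROR(ENOMEM);
        sws_scale(sws, (const uint8_t *[4]){ src, 0, 0, 0 },
                  (const int[4]){ src_linesize, 0, 0, 0 },
                  0, src_h,
                  (uint8_t *const [4]){ dst, 0, 0, 0 },
                  (const int[4]){ dst_linesize, 0, 0, 0 });
        sws_freeContext(sws);
        return 0;
    }

The filter achieves the same effect by passing in->data + 1 and in->linesize + 1 (and + 2) directly to sws_scale with a context created at the chroma dimensions, which avoids a per-frame context allocation.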
static av_cold void uninit(AVFilterContext *context)
{
    int i;
    SRContext *sr_context = context->priv;

    if (sr_context->dnn_module){
@@ -257,16 +218,14 @@ static av_cold void uninit(AVFilterContext *context)
        av_freep(&sr_context->dnn_module);
    }

    for (i = 0; i < 3; ++i){
        sws_freeContext(sr_context->sws_contexts[i]);
    }
    sws_freeContext(sr_context->sws_uv_scale);
    sws_freeContext(sr_context->sws_pre_scale);
}

static const AVFilterPad sr_inputs[] = {
    {
        .name = "default",
        .type = AVMEDIA_TYPE_VIDEO,
        .config_props = config_props,
        .filter_frame = filter_frame,
    },
    { NULL }
@@ -275,6 +234,7 @@ static const AVFilterPad sr_inputs[] = {
static const AVFilterPad sr_outputs[] = {
    {
        .name = "default",
        .config_props = config_output,
        .type = AVMEDIA_TYPE_VIDEO,
    },
    { NULL }
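Finally, note why .config_props moved from the input pad to the output pad: the old config_props reached across from the input link to context->outputs[0] to set the output size, while setting outlink->w/h is conventionally done in the output pad's own config callback, and the rewrite takes the opportunity to follow that convention. The resulting wiring, condensed from the pad tables above:

    // Geometry is computed in config_output after the model probe, so the
    // callback hangs off the output pad; the input pad only filters frames.
    static const AVFilterPad sr_inputs[] = {
        { .name = "default", .type = AVMEDIA_TYPE_VIDEO, .filter_frame = filter_frame },
        { NULL }
    };
    static const AVFilterPad sr_outputs[] = {
        { .name = "default", .type = AVMEDIA_TYPE_VIDEO, .config_props = config_output },
        { NULL }
    };

User-visible behavior is unchanged; an invocation such as ffmpeg -i in.png -vf sr=dnn_backend=native:scale_factor=2:model=sr.model out.png (the model path is illustrative) exercises the new probe-and-configure path.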