|
|
@@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) |
|
|
|
return graph_buf; |
|
|
|
} |
|
|
|
|
|
|
|
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) |
|
|
|
static TF_Tensor *allocate_input_tensor(const DNNInputData *input) |
|
|
|
{ |
|
|
|
TFModel *tf_model = (TFModel *)model; |
|
|
|
TF_DataType dt; |
|
|
|
size_t size; |
|
|
|
int64_t input_dims[] = {1, input->height, input->width, input->channels}; |
|
|
|
switch (input->dt) { |
|
|
|
case DNN_FLOAT: |
|
|
|
dt = TF_FLOAT; |
|
|
|
size = sizeof(float); |
|
|
|
break; |
|
|
|
case DNN_UINT8: |
|
|
|
dt = TF_UINT8; |
|
|
|
size = sizeof(char); |
|
|
|
break; |
|
|
|
default: |
|
|
|
av_assert0(!"should not reach here"); |
|
|
|
} |
|
|
|
|
|
|
|
return TF_AllocateTensor(dt, input_dims, 4, |
|
|
|
input_dims[1] * input_dims[2] * input_dims[3] * size); |
|
|
|
} |
|
|
|
|
|
|
|
static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) |
|
|
|
{ |
|
|
|
TFModel *tf_model = (TFModel *)model; |
|
|
|
TF_SessionOptions *sess_opts; |
|
|
|
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); |
|
|
|
|
|
|
@@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char |
|
|
|
if (tf_model->input_tensor){ |
|
|
|
TF_DeleteTensor(tf_model->input_tensor); |
|
|
|
} |
|
|
|
tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, |
|
|
|
input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)); |
|
|
|
tf_model->input_tensor = allocate_input_tensor(input); |
|
|
|
if (!tf_model->input_tensor){ |
|
|
|
return DNN_ERROR; |
|
|
|
} |
|
|
|