|
|
|
@@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename) |
|
|
|
return graph_buf; |
|
|
|
} |
|
|
|
|
|
|
|
static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output) |
|
|
|
static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name) |
|
|
|
{ |
|
|
|
TFModel *tf_model = (TFModel *)model; |
|
|
|
int64_t input_dims[] = {1, input->height, input->width, input->channels}; |
|
|
|
@@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o |
|
|
|
const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); |
|
|
|
TF_Tensor *output_tensor; |
|
|
|
|
|
|
|
// Input operation should be named 'x' |
|
|
|
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x"); |
|
|
|
// Input operation |
|
|
|
tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name); |
|
|
|
if (!tf_model->input.oper){ |
|
|
|
return DNN_ERROR; |
|
|
|
} |
|
|
|
@@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o |
|
|
|
} |
|
|
|
input->data = (float *)TF_TensorData(tf_model->input_tensor); |
|
|
|
|
|
|
|
// Output operation should be named 'y' |
|
|
|
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y"); |
|
|
|
// Output operation |
|
|
|
tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, output_name); |
|
|
|
if (!tf_model->output.oper){ |
|
|
|
return DNN_ERROR; |
|
|
|
} |
|
|
|
|