From 4d9470dce260f95dcea5b0aaa73c235283690eb1 Mon Sep 17 00:00:00 2001 From: Rs Date: Tue, 22 Jan 2019 13:56:56 +0100 Subject: [PATCH] Fix Bag-input --- mirror/server.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/mirror/server.py b/mirror/server.py index 471b438..2b8689c 100644 --- a/mirror/server.py +++ b/mirror/server.py @@ -28,8 +28,14 @@ def build(self, inputs, model, visualisations=[]): if len(inputs) <= 0: raise ValueError('At least one input is required.') self.inputs, self.model = inputs, model - - self.current_input = self.inputs[0].unsqueeze(0).to(self.device) # add 1 dim for batch + + #Iterate bag of inputs + for i in range(len(self.inputs)): + # Add mini-batch dim if not exist + if len(self.inputs[i].size()) == 1: + self.inputs[i] = self.inputs[i].unsqueeze(0).to(self.device) + + self.current_input = self.inputs model = model.to(self.device) model.eval() # instantiate a Tracer object to create a graph from the model @@ -172,4 +178,4 @@ def api_model_layer_output_image(input_id, vis_id, layer_id, time, output_id): except KeyError: return Response(status=500, response='Index not found.') - return app \ No newline at end of file + return app