diff --git a/mirror/server.py b/mirror/server.py index 471b438..2b8689c 100644 --- a/mirror/server.py +++ b/mirror/server.py @@ -28,8 +28,14 @@ def build(self, inputs, model, visualisations=[]): if len(inputs) <= 0: raise ValueError('At least one input is required.') self.inputs, self.model = inputs, model - - self.current_input = self.inputs[0].unsqueeze(0).to(self.device) # add 1 dim for batch + + #Iterate bag of inputs + for i in range(len(self.inputs)): + # Add mini-batch dim if not exist + if len(self.inputs[i].size()) == 1: + self.inputs[i] = self.inputs[i].unsqueeze(0).to(self.device) + + self.current_input = self.inputs model = model.to(self.device) model.eval() # instantiate a Tracer object to create a graph from the model @@ -172,4 +178,4 @@ def api_model_layer_output_image(input_id, vis_id, layer_id, time, output_id): except KeyError: return Response(status=500, response='Index not found.') - return app \ No newline at end of file + return app