Skip to content

Commit 7482cf2

Browse files
authored
Fix flux tuning device issue (#2352)
1 parent 0a87139 commit 7482cf2

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

examples/pytorch/diffusion_model/diffusers/flux/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ This example quantizes and validates the accuracy of Flux.
88

99
```shell
1010
pip install -r requirements.txt
11-
# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/neural-compressor.git@v3.6rc` for the latest updates before neural-compressor v3.6 release
12-
pip install neural-compressor-pt==3.6
13-
# Use `pip install git+https://github.com/intel/auto-round.git@v0.8.0rc2` for the latest updates before auto-round v0.8.0 release
14-
pip install auto-round==0.8.0
11+
# Use `INC_PT_ONLY=1 pip install git+https://github.com/intel/neural-compressor.git@master` for the latest updates before neural-compressor v3.6 release
12+
pip install neural-compressor-pt==3.7
13+
# Use `pip install git+https://github.com/intel/auto-round.git@main` for the latest updates before auto-round v0.8.0 release
14+
pip install auto-round==0.9.3
1515
```
1616

1717
## 2. Prepare Model

examples/pytorch/diffusion_model/diffusers/flux/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def inference_worker(eval_file, pipe, image_save_dir):
8181
output.images[idx].save(os.path.join(image_save_dir, str(image_id) + ".png"))
8282

8383

84-
def tune():
85-
pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16)
84+
def tune(device):
85+
pipe = AutoPipelineForText2Image.from_pretrained(args.model, torch_dtype=torch.bfloat16).to(device)
8686
model = pipe.transformer
8787
layer_config = {}
8888
kwargs = {}
@@ -116,7 +116,7 @@ def tune():
116116

117117
if args.quantize:
118118
print(f"Start to quantize {args.model}.")
119-
tune()
119+
tune(device)
120120
exit(0)
121121

122122
if args.inference:

examples/pytorch/diffusion_model/diffusers/flux/run_benchmark.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function run_benchmark {
8686

8787
echo "Start calculating final score..."
8888

89-
python3 main.py --output_image_path ${output_image_path} --accuracy
89+
python3 main.py --output_image_path ${output_image_path} --accuracy --eval_dataset ${dataset_location}
9090
}
9191

9292
main "$@"

0 commit comments

Comments
 (0)