Skip to content

Commit ecdae7c

Browse files
christian-pintoDarkLight1337
authored andcommitted
[Misc] Terratorch related fixes (vllm-project#24337)
Signed-off-by: Christian Pinto <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 597f586 commit ecdae7c

File tree

11 files changed

+18
-37
lines changed

11 files changed

+18
-37
lines changed

examples/offline_inference/prithvi_geospatial_mae_io_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
def main():
2020
torch.set_default_dtype(torch.float16)
21-
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
21+
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
2222

2323
img_prompt = dict(
2424
data=image_url,
@@ -36,7 +36,7 @@ def main():
3636
# to avoid the model going OOM.
3737
# The maximum number depends on the available GPU memory
3838
max_num_seqs=32,
39-
io_processor_plugin="prithvi_to_tiff_india",
39+
io_processor_plugin="prithvi_to_tiff",
4040
model_impl="terratorch",
4141
)
4242

examples/online_serving/prithvi_geospatial_mae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
# --model-impl terratorch
1919
# --task embed --trust-remote-code
2020
# --skip-tokenizer-init --enforce-eager
21-
# --io-processor-plugin prithvi_to_tiff_india
21+
# --io-processor-plugin prithvi_to_tiff
2222

2323

2424
def main():
25-
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
25+
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
2626
server_endpoint = "http://localhost:8000/pooling"
2727

2828
request_payload_url = {

requirements/test.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ runai-model-streamer-s3==0.11.0
5454
fastsafetensors>=0.1.10
5555
pydantic>=2.10 # 2.9 leads to error on python 3.10
5656
decord==0.6.0
57-
terratorch==1.1rc3 # required for PrithviMAE test
57+
terratorch @ git+https://github.com/IBM/[email protected] # required for PrithviMAE test

requirements/test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
10421042
# via lightning
10431043
tensorizer==2.10.1
10441044
# via -r requirements/test.in
1045-
terratorch==1.1rc3
1045+
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
10461046
# via -r requirements/test.in
10471047
threadpoolctl==3.5.0
10481048
# via scikit-learn

tests/entrypoints/openai/test_skip_tokenizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ...utils import RemoteOpenAIServer
1313

14-
MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
14+
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
1515
DTYPE = "float16"
1616

1717

tests/models/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,15 +383,15 @@ def check_available_online(
383383
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
384384
trust_remote_code=True),
385385
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
386-
"PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
386+
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
387387
dtype=torch.float16,
388388
enforce_eager=True,
389389
skip_tokenizer_init=True,
390390
# This is to avoid the model
391391
# going OOM in CI
392392
max_num_seqs=32,
393393
),
394-
"Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
394+
"Terratorch": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
395395
dtype=torch.float16,
396396
enforce_eager=True,
397397
skip_tokenizer_init=True,

tests/models/test_terratorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@pytest.mark.parametrize(
1212
"model",
1313
[
14-
"mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
14+
"ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
1515
"mgazz/Prithvi_v2_eo_300_tl_unet_agb"
1616
],
1717
)
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
def register_prithvi_india():
4-
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501
53

64

7-
def register_prithvi_valencia():
8-
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501
5+
def register_prithvi():
6+
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessor" # noqa: E501

tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def load_image(
234234

235235
class PrithviMultimodalDataProcessor(IOProcessor):
236236

237+
indices = [0, 1, 2, 3, 4, 5]
238+
237239
def __init__(self, vllm_config: VllmConfig):
238240

239241
super().__init__(vllm_config)
@@ -412,21 +414,3 @@ def post_process(
412414
format="tiff",
413415
data=out_data,
414416
request_id=request_id)
415-
416-
417-
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
418-
419-
def __init__(self, vllm_config: VllmConfig):
420-
421-
super().__init__(vllm_config)
422-
423-
self.indices = [1, 2, 3, 8, 11, 12]
424-
425-
426-
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
427-
428-
def __init__(self, vllm_config: VllmConfig):
429-
430-
super().__init__(vllm_config)
431-
432-
self.indices = [0, 1, 2, 3, 4, 5]

tests/plugins/prithvi_io_processor_plugin/setup.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
packages=["prithvi_io_processor"],
1010
entry_points={
1111
"vllm.io_processor_plugins": [
12-
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501
13-
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501
12+
"prithvi_to_tiff = prithvi_io_processor:register_prithvi", # noqa: E501
1413
]
1514
},
1615
)

0 commit comments

Comments
 (0)