@@ -24,6 +24,8 @@ const hived = require('@pai/middlewares/v2/hived');
2424const { enabledHived } = require ( '@pai/config/launcher' ) ;
2525const protocolSchema = require ( '@pai/config/v2/protocol' ) ;
2626const asyncHandler = require ( '@pai/middlewares/v2/asyncHandler' ) ;
27+ const logger = require ( '@pai/config/logger' ) ;
28+ const launcherConfig = require ( '@pai/config/launcher' ) ;
2729
2830const mustacheWriter = new mustache . Writer ( ) ;
2931
@@ -53,6 +55,19 @@ const render = (template, dict, tags = ['<%', '%>']) => {
5355 return result . trim ( ) ;
5456} ;
5557
58+ const getImageName = ( prerequisites ) => {
59+ if (
60+ typeof prerequisites !== 'object' ||
61+ typeof prerequisites . dockerimage !== 'object'
62+ ) {
63+ return null ;
64+ }
65+
66+ const dockerImages = Object . values ( prerequisites . dockerimage ) ;
67+ const img = dockerImages . find ( p => p . type === 'dockerimage' && p . uri ) ;
68+ return img ?. uri ?? null ;
69+ } ;
70+
5671const protocolValidate = ( protocolYAML ) => {
5772 const protocolObj = yaml . load ( protocolYAML ) ;
5873 if ( ! protocolSchema . validate ( protocolObj ) ) {
@@ -87,6 +102,42 @@ const protocolValidate = (protocolYAML) => {
87102 prerequisiteSet . add ( item . name ) ;
88103 }
89104 }
105+ const imageRegexPattern = launcherConfig . imageRegex ;
106+ let imageRegex ;
107+
108+ try {
109+ if ( imageRegexPattern && imageRegexPattern . length > 0 ) {
110+ imageRegex = new RegExp ( imageRegexPattern ) ;
111+ }
112+ } catch ( error ) {
113+ logger . info ( `Invalid imageRegex pattern: ${ imageRegexPattern } . Error: ${ error . message } ` ) ;
114+ throw createError (
115+ 'Internal Server Error' ,
116+ 'InvalidImageRegexError' ,
117+ `The provided imageRegex pattern "${ imageRegexPattern } " is invalid.`
118+ ) ;
119+ }
120+
121+ if ( imageRegex ) {
122+ const imageName = getImageName ( protocolObj . prerequisites ) ;
123+ // Check if the imageName matches the imageRegex
124+ if ( ! imageName ) {
125+ throw createError (
126+ 'Bad Request' ,
127+ 'NoDockerImageError' ,
128+ 'No valid docker image found in prerequisites.'
129+ ) ;
130+ }
131+ // Check if the imageName matches the imageRegex
132+ const match = imageRegex . test ( imageName ) ;
133+ if ( ! match ) {
134+ throw createError (
135+ 'Bad Request' ,
136+ 'InvalidImageError' ,
137+ `The image ${ imageName } is not allowed.`
138+ ) ;
139+ }
140+ }
90141 }
91142 protocolObj . prerequisites = prerequisites ;
92143 // convert deployments list to dict
@@ -120,6 +171,34 @@ const protocolValidate = (protocolYAML) => {
120171 }
121172 }
122173 }
174+
175+ // check jobType
176+ if ( 'jobType' in protocolObj ) {
177+ if ( protocolObj . jobType === 'inference' ) {
178+ // check parameters for inference job
179+ if ( ! ( 'parameters' in protocolObj ) ) {
180+ throw createError (
181+ 'Bad Request' ,
182+ 'InvalidProtocolError' ,
183+ `The following parameters must be specified for inference job:
184+ INTERNAL_SERVER_IP=$PAI_HOST_IP_taskrole_0
185+ INTERNAL_SERVER_PORT=$PAI_PORT_LIST_taskrole_0_http
186+ API_KEY=[any string]` ,
187+ ) ;
188+ }
189+ const requiredParams = [ 'INTERNAL_SERVER_IP' , 'INTERNAL_SERVER_PORT' , 'API_KEY' ] ;
190+ for ( const param of requiredParams ) {
191+ if ( ! ( param in protocolObj . parameters ) ) {
192+ throw createError (
193+ 'Bad Request' ,
194+ 'InvalidProtocolError' ,
195+ `Parameter ${ param } must be specified for inference job.` ,
196+ ) ;
197+ }
198+ }
199+ }
200+ }
201+
123202 for ( const taskRole of Object . keys ( protocolObj . taskRoles ) ) {
124203 for ( const field of prerequisiteFields ) {
125204 if (
@@ -194,7 +273,7 @@ const protocolRender = (protocolObj) => {
194273 ] ,
195274 $output :
196275 protocolObj . prerequisites . output [
197- protocolObj . taskRoles [ taskRole ] . output
276+ protocolObj . taskRoles [ taskRole ] . output
198277 ] ,
199278 $data :
200279 protocolObj . prerequisites . data [ protocolObj . taskRoles [ taskRole ] . data ] ,
0 commit comments