Skip to content

Commit 4eca0c5

Browse files
authored
RestServer - support inference/training jobType in job protocal (#105)
* update * check job protocol for inference job * support jobType when query jobs * update * fix tag filter bug * update * update * update * update
1 parent 6c49415 commit 4eca0c5

File tree

4 files changed

+91
-28
lines changed

4 files changed

+91
-28
lines changed

src/rest-server/src/config/v2/protocol.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ const protocolSchema = {
173173
},
174174
minItems: 1,
175175
},
176+
jobType: {
177+
type: 'string',
178+
enum: ['inference', 'training', 'others'],
179+
default: 'others',
180+
},
176181
parameters: {
177182
type: 'object',
178183
additionalProperties: true,

src/rest-server/src/controllers/v2/job.js

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ const list = asyncHandler(async (req, res) => {
8686
if ('tagsNotContain' in req.query) {
8787
tagsNotContainFilter.name = req.query.tagsNotContain.split(',');
8888
}
89+
if ('jobType' in req.query) {
90+
// validate jobType values
91+
const validJobTypes = ['inference', 'training', 'others'];
92+
const requestedTypes = req.query.jobType.split(',');
93+
const invalidTypes = requestedTypes.filter(type => !validJobTypes.includes(type));
94+
if (invalidTypes.length > 0) {
95+
throw createError(
96+
'Bad Request',
97+
'InvalidParametersError',
98+
`Invalid job type(s): ${invalidTypes.join(', ')}`
99+
);
100+
}
101+
if (Array.isArray(tagsContainFilter.name)) {
102+
tagsContainFilter.name.push(...requestedTypes);
103+
} else {
104+
tagsContainFilter.name = requestedTypes;
105+
}
106+
}
89107
if ('keyword' in req.query) {
90108
// match text in username, jobname, or vc
91109
filters[Op.or] = [
@@ -199,6 +217,7 @@ const update = asyncHandler(async (req, res) => {
199217
const jobName = res.locals.protocol.name;
200218
const userName = req.user.username;
201219
const frameworkName = `${userName}~${jobName}`;
220+
const jobType = res.locals.protocol.jobType || 'others';
202221

203222
// check duplicate job
204223
try {
@@ -216,6 +235,7 @@ const update = asyncHandler(async (req, res) => {
216235
}
217236
}
218237
await job.put(frameworkName, res.locals.protocol, req.body);
238+
await job.addTag(frameworkName, jobType);
219239
res.status(status('Accepted')).json({
220240
status: status('Accepted'),
221241
message: `Update job ${jobName} for user ${userName} successfully.`,
@@ -368,10 +388,10 @@ const getLogs = asyncHandler(async (req, res) => {
368388
throw error.code === 'NoTaskLogError'
369389
? error
370390
: createError(
371-
'Internal Server Error',
372-
'UnknownError',
373-
'Failed to get log list',
374-
);
391+
'Internal Server Error',
392+
'UnknownError',
393+
'Failed to get log list',
394+
);
375395
}
376396
});
377397

src/rest-server/src/middlewares/v2/protocol.js

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,34 @@ const protocolValidate = (protocolYAML) => {
120120
}
121121
}
122122
}
123+
124+
// check jobType
125+
if ('jobType' in protocolObj) {
126+
if (protocolObj.jobType === 'inference') {
127+
// check parameters for inference job
128+
if (!('parameters' in protocolObj)) {
129+
throw createError(
130+
'Bad Request',
131+
'InvalidProtocolError',
132+
`The following parameters must be specified for inference job:
133+
INTERNAL_SERVER_IP=$PAI_HOST_IP_taskrole_0
134+
INTERNAL_SERVER_PORT=$PAI_PORT_LIST_taskrole_0_http
135+
API_KEY=[any string]`,
136+
);
137+
}
138+
const requiredParams = ['INTERNAL_SERVER_IP', 'INTERNAL_SERVER_PORT', 'API_KEY'];
139+
for (const param of requiredParams) {
140+
if (!(param in protocolObj.parameters)) {
141+
throw createError(
142+
'Bad Request',
143+
'InvalidProtocolError',
144+
`Parameter ${param} must be specified for inference job.`,
145+
);
146+
}
147+
}
148+
}
149+
}
150+
123151
for (const taskRole of Object.keys(protocolObj.taskRoles)) {
124152
for (const field of prerequisiteFields) {
125153
if (
@@ -194,7 +222,7 @@ const protocolRender = (protocolObj) => {
194222
],
195223
$output:
196224
protocolObj.prerequisites.output[
197-
protocolObj.taskRoles[taskRole].output
225+
protocolObj.taskRoles[taskRole].output
198226
],
199227
$data:
200228
protocolObj.prerequisites.data[protocolObj.taskRoles[taskRole].data],

src/rest-server/src/models/v2/job/k8s.js

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,33 +1000,43 @@ const list = async (
10001000
Object.keys(tagsContainFilter).length !== 0 ||
10011001
Object.keys(tagsNotContainFilter).length !== 0
10021002
) {
1003-
filters.name = {};
1004-
// tagsContain
1003+
// Build name filters by querying Tag table directly to avoid using QueryGenerator
1004+
const nameFilter = {};
1005+
1006+
// tagsContain -> include frameworks whose name appears in Tag rows matched by tagsContainFilter
10051007
if (Object.keys(tagsContainFilter).length !== 0) {
1006-
const queryContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
1007-
'tags',
1008-
{
1009-
attributes: ['frameworkName'],
1010-
where: tagsContainFilter,
1011-
},
1012-
);
1013-
filters.name[Sequelize.Op.in] = Sequelize.literal(`
1014-
(${queryContainFrameworkName.slice(0, -1)})
1015-
`);
1008+
const containRows = await databaseModel.Tag.findAll({
1009+
attributes: ['frameworkName'],
1010+
where: tagsContainFilter,
1011+
raw: true,
1012+
});
1013+
const containNames = [...new Set(containRows.map((r) => r.frameworkName))];
1014+
// if no tags match, result is empty
1015+
if (containNames.length === 0) {
1016+
if (withTotalCount) {
1017+
return { totalCount: 0, data: [] };
1018+
} else {
1019+
return [];
1020+
}
1021+
}
1022+
nameFilter[Sequelize.Op.in] = containNames;
10161023
}
1017-
// tagsNotContain
1024+
1025+
// tagsNotContain -> exclude frameworks whose name appears in Tag rows matched by tagsNotContainFilter
10181026
if (Object.keys(tagsNotContainFilter).length !== 0) {
1019-
const queryNotContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
1020-
'tags',
1021-
{
1022-
attributes: ['frameworkName'],
1023-
where: tagsNotContainFilter,
1024-
},
1025-
);
1026-
filters.name[Sequelize.Op.notIn] = Sequelize.literal(`
1027-
(${queryNotContainFrameworkName.slice(0, -1)})
1028-
`);
1027+
const notContainRows = await databaseModel.Tag.findAll({
1028+
attributes: ['frameworkName'],
1029+
where: tagsNotContainFilter,
1030+
raw: true,
1031+
});
1032+
const notContainNames = [...new Set(notContainRows.map((r) => r.frameworkName))];
1033+
if (notContainNames.length > 0) {
1034+
nameFilter[Sequelize.Op.notIn] = notContainNames;
1035+
}
10291036
}
1037+
1038+
// merge with any existing name filter
1039+
filters.name = Object.assign({}, filters.name || {}, nameFilter);
10301040
}
10311041

10321042
frameworks = await databaseModel.Framework.findAll({

0 commit comments

Comments
 (0)