Skip to content

Commit 43b85f4

Browse files
authored
Merge branch 'dev' into ruigao/cilium-update
2 parents 9e3edcb + adfb217 commit 43b85f4

File tree

16 files changed

+274
-69
lines changed

16 files changed

+274
-69
lines changed

.github/workflows/build-all.yaml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ permissions:
55

66
on:
77
push:
8-
branches: [main, 'release/*']
9-
pull_request:
10-
branches: [main, dev, 'release/*']
8+
branches: ['release/*']
119
release:
1210
types: [published]
1311
workflow_dispatch:
12+
inputs:
13+
branch:
14+
description: 'The branch name or tag to run the workflow on'
15+
required: true
16+
default: 'main'
17+
type: string
1418

1519
env:
1620
TAG: ${{ github.run_number }}
@@ -36,7 +40,7 @@ jobs:
3640
with:
3741
fetch-depth: 0
3842
submodules: false
39-
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.ref_name }}
43+
ref: ${{ github.event.inputs.branch || github.ref }}
4044

4145
- name: Get All Services
4246
id: all

.github/workflows/build-deploy-changes.yaml

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,27 @@ jobs:
171171
--overwrite-existing
172172
kubelogin convert-kubeconfig -l azurecli
173173
kubectl config use-context ${{ secrets.KUBERNETES_CLUSTER }}
174+
# Replace "webportal" with "webportal-dind" if "webportal" is changed
175+
services_to_deploy="${{ steps.changes.outputs.folders }}"
176+
if echo " $services_to_deploy " | grep -q " webportal "; then
177+
tmp=""
178+
for s in $services_to_deploy; do
179+
[ "$s" = "webportal" ] && continue
180+
[ "$s" = "webportal-dind" ] && continue
181+
tmp="$tmp $s"
182+
done
183+
services_to_deploy="$tmp webportal-dind"
184+
services_to_deploy=$(echo "$services_to_deploy" | xargs)
185+
fi
186+
echo "Final services to deploy: $services_to_deploy"
187+
174188
echo "${{ secrets.PAI_CLUSTER_NAME }}" > cluster_id
175-
echo "Stopping changed pai services \"${{ steps.changes.outputs.folders }}\" on ${{ secrets.PAI_CLUSTER_NAME }} ..."
176-
$GITHUB_WORKSPACE/paictl.py service stop -n ${{ steps.changes.outputs.folders }} < cluster_id
189+
echo "Stopping changed pai services $services_to_deploy on ${{ secrets.PAI_CLUSTER_NAME }} ..."
190+
$GITHUB_WORKSPACE/paictl.py service stop -n $services_to_deploy < cluster_id
177191
echo "Pushing config to cluster \"${{ secrets.PAI_CLUSTER_NAME }}\" ..."
178-
$GITHUB_WORKSPACE/paictl.py config push -m service -p $GITHUB_WORKSPACE/config/cluster-configuration < cluster_id
179-
echo "Starting to update \"${{ steps.changes.outputs.folders }}\" on ${{ secrets.PAI_CLUSTER_NAME }} ..."
180-
$GITHUB_WORKSPACE/paictl.py service start -n ${{ steps.changes.outputs.folders }} < cluster_id
192+
$GITHUB_WORKSPACE/paictl.py config push -m service -p $GITHUB_WORKSPACE/config/cluster-configuration < cluster_id
193+
echo "Starting to update $services_to_deploy on ${{ secrets.PAI_CLUSTER_NAME }} ..."
194+
$GITHUB_WORKSPACE/paictl.py service start -n $services_to_deploy < cluster_id
181195
kubectl get pod
182196
kubectl get service
183197

src/rest-server/deploy/rest-server.yaml.template

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ spec:
174174
value: {{ cluster_cfg['rest-server']['hived-computing-device-envs'] }}
175175
- name: ALERT_MANAGER_URL
176176
value: "{{ cluster_cfg['alert-manager']['url'] }}"
177+
- name: IMAGE_REGEX
178+
value: "{{ cluster_cfg['rest-server']['image-regex'] | default("") }}"
177179
ports:
178180
- name: rest-server
179181
containerPort: 8080

src/rest-server/src/config/launcher.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ const k8sLauncherConfigSchema = Joi.object()
8484
jobRestrictionGitScriptName: Joi.string().required(),
8585
clstoreHostPath: Joi.string().empty(''),
8686
clstoreJobPath: Joi.string().empty(''),
87+
imageRegex: Joi.string().empty(''),
8788
})
8889
.required();
8990

@@ -156,6 +157,7 @@ if (launcherType === 'k8s') {
156157
jobRestrictionGitScriptName: process.env.JOB_RESTRICTION_GIT_SCRIPT_NAME || 'unset',
157158
clstoreHostPath: process.env.CLUSTER_LOCAL_STORAGE_HOST_PATH,
158159
clstoreJobPath: process.env.CLUSTER_LOCAL_STORAGE_JOB_PATH,
160+
imageRegex: process.env.IMAGE_REGEX || '',
159161
};
160162

161163
const { error, value } = k8sLauncherConfigSchema.validate(launcherConfig);

src/rest-server/src/config/v2/protocol.js

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ const protocolSchema = {
173173
},
174174
minItems: 1,
175175
},
176+
jobType: {
177+
type: 'string',
178+
enum: ['inference', 'training', 'others'],
179+
default: 'others',
180+
},
176181
parameters: {
177182
type: 'object',
178183
additionalProperties: true,

src/rest-server/src/controllers/v2/job.js

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,24 @@ const list = asyncHandler(async (req, res) => {
8686
if ('tagsNotContain' in req.query) {
8787
tagsNotContainFilter.name = req.query.tagsNotContain.split(',');
8888
}
89+
if ('jobType' in req.query) {
90+
// validate jobType values
91+
const validJobTypes = ['inference', 'training', 'others'];
92+
const requestedTypes = req.query.jobType.split(',');
93+
const invalidTypes = requestedTypes.filter(type => !validJobTypes.includes(type));
94+
if (invalidTypes.length > 0) {
95+
throw createError(
96+
'Bad Request',
97+
'InvalidParametersError',
98+
`Invalid job type(s): ${invalidTypes.join(', ')}`
99+
);
100+
}
101+
if (Array.isArray(tagsContainFilter.name)) {
102+
tagsContainFilter.name.push(...requestedTypes);
103+
} else {
104+
tagsContainFilter.name = requestedTypes;
105+
}
106+
}
89107
if ('keyword' in req.query) {
90108
// match text in username, jobname, or vc
91109
filters[Op.or] = [
@@ -199,6 +217,7 @@ const update = asyncHandler(async (req, res) => {
199217
const jobName = res.locals.protocol.name;
200218
const userName = req.user.username;
201219
const frameworkName = `${userName}~${jobName}`;
220+
const jobType = res.locals.protocol.jobType || 'others';
202221

203222
// check duplicate job
204223
try {
@@ -216,6 +235,7 @@ const update = asyncHandler(async (req, res) => {
216235
}
217236
}
218237
await job.put(frameworkName, res.locals.protocol, req.body);
238+
await job.addTag(frameworkName, jobType);
219239
res.status(status('Accepted')).json({
220240
status: status('Accepted'),
221241
message: `Update job ${jobName} for user ${userName} successfully.`,
@@ -368,10 +388,10 @@ const getLogs = asyncHandler(async (req, res) => {
368388
throw error.code === 'NoTaskLogError'
369389
? error
370390
: createError(
371-
'Internal Server Error',
372-
'UnknownError',
373-
'Failed to get log list',
374-
);
391+
'Internal Server Error',
392+
'UnknownError',
393+
'Failed to get log list',
394+
);
375395
}
376396
});
377397

src/rest-server/src/middlewares/v2/protocol.js

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ const hived = require('@pai/middlewares/v2/hived');
2424
const { enabledHived } = require('@pai/config/launcher');
2525
const protocolSchema = require('@pai/config/v2/protocol');
2626
const asyncHandler = require('@pai/middlewares/v2/asyncHandler');
27+
const logger = require('@pai/config/logger');
28+
const launcherConfig = require('@pai/config/launcher');
2729

2830
const mustacheWriter = new mustache.Writer();
2931

@@ -53,6 +55,19 @@ const render = (template, dict, tags = ['<%', '%>']) => {
5355
return result.trim();
5456
};
5557

58+
/**
 * Return the URI of the first docker image declared in a job's prerequisites.
 *
 * @param {object} prerequisites - Prerequisites keyed by type; the
 *   `dockerimage` entry is read as a name -> prerequisite-item mapping.
 * @returns {string|null} The first truthy `uri` of an item whose `type` is
 *   'dockerimage', or null when the input is malformed or nothing matches.
 */
const getImageName = (prerequisites) => {
  // `typeof null` is 'object', so null has to be rejected explicitly;
  // otherwise the property access or Object.values() would throw a TypeError
  // instead of yielding the documented null result.
  const dockerImages = prerequisites?.dockerimage;
  if (
    typeof prerequisites !== 'object' ||
    typeof dockerImages !== 'object' ||
    dockerImages === null
  ) {
    return null;
  }

  // Tolerate null/undefined entries rather than crashing on `p.type`.
  const img = Object.values(dockerImages).find(
    (p) => p?.type === 'dockerimage' && p.uri,
  );
  return img?.uri ?? null;
};
70+
5671
const protocolValidate = (protocolYAML) => {
5772
const protocolObj = yaml.load(protocolYAML);
5873
if (!protocolSchema.validate(protocolObj)) {
@@ -87,6 +102,42 @@ const protocolValidate = (protocolYAML) => {
87102
prerequisiteSet.add(item.name);
88103
}
89104
}
105+
const imageRegexPattern = launcherConfig.imageRegex;
106+
let imageRegex;
107+
108+
try {
109+
if (imageRegexPattern && imageRegexPattern.length > 0) {
110+
imageRegex = new RegExp(imageRegexPattern);
111+
}
112+
} catch (error) {
113+
logger.info(`Invalid imageRegex pattern: ${imageRegexPattern}. Error: ${error.message}`);
114+
throw createError(
115+
'Internal Server Error',
116+
'InvalidImageRegexError',
117+
`The provided imageRegex pattern "${imageRegexPattern}" is invalid.`
118+
);
119+
}
120+
121+
if (imageRegex) {
122+
const imageName = getImageName(protocolObj.prerequisites);
123+
// Reject the job when no usable docker image is found in prerequisites
124+
if (!imageName) {
125+
throw createError(
126+
'Bad Request',
127+
'NoDockerImageError',
128+
'No valid docker image found in prerequisites.'
129+
);
130+
}
131+
// Check if the imageName matches the imageRegex
132+
const match = imageRegex.test(imageName);
133+
if (!match) {
134+
throw createError(
135+
'Bad Request',
136+
'InvalidImageError',
137+
`The image ${imageName} is not allowed.`
138+
);
139+
}
140+
}
90141
}
91142
protocolObj.prerequisites = prerequisites;
92143
// convert deployments list to dict
@@ -120,6 +171,34 @@ const protocolValidate = (protocolYAML) => {
120171
}
121172
}
122173
}
174+
175+
// check jobType
176+
if ('jobType' in protocolObj) {
177+
if (protocolObj.jobType === 'inference') {
178+
// check parameters for inference job
179+
if (!('parameters' in protocolObj)) {
180+
throw createError(
181+
'Bad Request',
182+
'InvalidProtocolError',
183+
`The following parameters must be specified for inference job:
184+
INTERNAL_SERVER_IP=$PAI_HOST_IP_taskrole_0
185+
INTERNAL_SERVER_PORT=$PAI_PORT_LIST_taskrole_0_http
186+
API_KEY=[any string]`,
187+
);
188+
}
189+
const requiredParams = ['INTERNAL_SERVER_IP', 'INTERNAL_SERVER_PORT', 'API_KEY'];
190+
for (const param of requiredParams) {
191+
if (!(param in protocolObj.parameters)) {
192+
throw createError(
193+
'Bad Request',
194+
'InvalidProtocolError',
195+
`Parameter ${param} must be specified for inference job.`,
196+
);
197+
}
198+
}
199+
}
200+
}
201+
123202
for (const taskRole of Object.keys(protocolObj.taskRoles)) {
124203
for (const field of prerequisiteFields) {
125204
if (
@@ -194,7 +273,7 @@ const protocolRender = (protocolObj) => {
194273
],
195274
$output:
196275
protocolObj.prerequisites.output[
197-
protocolObj.taskRoles[taskRole].output
276+
protocolObj.taskRoles[taskRole].output
198277
],
199278
$data:
200279
protocolObj.prerequisites.data[protocolObj.taskRoles[taskRole].data],

src/rest-server/src/middlewares/v2/quota.js

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,6 @@ const getReqeustedGpuCount = (protocol) => {
7575
return gpuCount;
7676
};
7777

78-
function getImageName(prerequisites) {
79-
if (
80-
typeof prerequisites !== 'object' ||
81-
typeof prerequisites.dockerimage !== 'object'
82-
) {
83-
return null;
84-
}
85-
86-
const dockerImages = Object.values(prerequisites.dockerimage);
87-
const img = dockerImages.find(p => p.type === 'dockerimage' && p.uri);
88-
return img?.uri ?? null;
89-
}
90-
9178
const getJobPriority = (protocol) => {
9279
return protocol.extras?.hivedScheduler?.jobPriorityClass || null;
9380
};
@@ -104,21 +91,6 @@ const check = async (req, res, next) => {
10491
const userPrioritySet = userInfo.extension?.jobPriority ?? null;
10592
const userPriorityExpiration = userInfo.extension?.jobExpiration ?? null;
10693

107-
const imageName = getImageName(jobProtocol.prerequisites);
108-
const imageRepo = imageName.split('/')[0].trim().toLowerCase();
109-
const isAzureCR = /\.azurecr\.io$/i.test(imageRepo);
110-
// Reject if it’s not ACR
111-
if (!isAzureCR) {
112-
return next(
113-
createError(
114-
'Forbidden',
115-
'InvalidImageError',
116-
`The image ${imageRepo || imageName} is not allowed. ` +
117-
`Please use an image from Azure Container Registry (ACR).`
118-
)
119-
);
120-
}
121-
12294
if (jobPriority === 'prod') {
12395
if (userPrioritySet !== 1) {
12496
logger.debug(

src/rest-server/src/models/v2/job/k8s.js

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,33 +1000,43 @@ const list = async (
10001000
Object.keys(tagsContainFilter).length !== 0 ||
10011001
Object.keys(tagsNotContainFilter).length !== 0
10021002
) {
1003-
filters.name = {};
1004-
// tagsContain
1003+
// Build name filters by querying Tag table directly to avoid using QueryGenerator
1004+
const nameFilter = {};
1005+
1006+
// tagsContain -> include frameworks whose name appears in Tag rows matched by tagsContainFilter
10051007
if (Object.keys(tagsContainFilter).length !== 0) {
1006-
const queryContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
1007-
'tags',
1008-
{
1009-
attributes: ['frameworkName'],
1010-
where: tagsContainFilter,
1011-
},
1012-
);
1013-
filters.name[Sequelize.Op.in] = Sequelize.literal(`
1014-
(${queryContainFrameworkName.slice(0, -1)})
1015-
`);
1008+
const containRows = await databaseModel.Tag.findAll({
1009+
attributes: ['frameworkName'],
1010+
where: tagsContainFilter,
1011+
raw: true,
1012+
});
1013+
const containNames = [...new Set(containRows.map((r) => r.frameworkName))];
1014+
// if no tags match, result is empty
1015+
if (containNames.length === 0) {
1016+
if (withTotalCount) {
1017+
return { totalCount: 0, data: [] };
1018+
} else {
1019+
return [];
1020+
}
1021+
}
1022+
nameFilter[Sequelize.Op.in] = containNames;
10161023
}
1017-
// tagsNotContain
1024+
1025+
// tagsNotContain -> exclude frameworks whose name appears in Tag rows matched by tagsNotContainFilter
10181026
if (Object.keys(tagsNotContainFilter).length !== 0) {
1019-
const queryNotContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery(
1020-
'tags',
1021-
{
1022-
attributes: ['frameworkName'],
1023-
where: tagsNotContainFilter,
1024-
},
1025-
);
1026-
filters.name[Sequelize.Op.notIn] = Sequelize.literal(`
1027-
(${queryNotContainFrameworkName.slice(0, -1)})
1028-
`);
1027+
const notContainRows = await databaseModel.Tag.findAll({
1028+
attributes: ['frameworkName'],
1029+
where: tagsNotContainFilter,
1030+
raw: true,
1031+
});
1032+
const notContainNames = [...new Set(notContainRows.map((r) => r.frameworkName))];
1033+
if (notContainNames.length > 0) {
1034+
nameFilter[Sequelize.Op.notIn] = notContainNames;
1035+
}
10291036
}
1037+
1038+
// merge with any existing name filter
1039+
filters.name = Object.assign({}, filters.name || {}, nameFilter);
10301040
}
10311041

10321042
frameworks = await databaseModel.Framework.findAll({

src/rest-server/src/utils/error.d.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ declare type Code =
3333
'ForbiddenUserError' |
3434
'ForbiddenKeyError' |
3535
'IncorrectPasswordError' |
36+
'InvalidImageError' |
37+
'InvalidImageRegexError' |
3638
'InvalidParametersError' |
3739
'NoApiError' |
40+
'NoDockerImageError' |
3841
'NoJobError' |
3942
'NoJobConfigError' |
4043
'NoJobSshInfoError' |

0 commit comments

Comments
 (0)