Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions application/ui/mocks/mock-pipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { SchemaPipeline } from './../src/api/openapi-spec.d';

export const getMockedPipeline = (customPipeline?: Partial<SchemaPipeline>): SchemaPipeline => {
return {
project_id: '123',
status: 'running' as const,
source: {
id: 'source-id',
name: 'source',
project_id: '123',
source_type: 'video_file' as const,
video_path: 'video.mp4',
},
model: {
id: '1',
name: 'Object_Detection_TestModel',
format: 'onnx' as const,
project_id: '123',
threshold: 0.5,
is_ready: true,
train_job_id: 'train-job-1',
},
sink: {
id: 'sink-id',
name: 'sink',
project_id: '123',
folder_path: 'data/sink',
output_formats: ['image_original', 'image_with_predictions', 'predictions'] as Array<
'image_original' | 'image_with_predictions' | 'predictions'
>,
rate_limit: 0.2,
sink_type: 'folder' as const,
},
...customPipeline,
};
};
115 changes: 115 additions & 0 deletions application/ui/src/features/inspect/stream/stream-container.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
import { render, screen } from '@testing-library/react';
import userEvent from '@testing-library/user-event';
import { HttpResponse } from 'msw';
import { MemoryRouter, Route, Routes } from 'react-router-dom';
import { SchemaPipeline } from 'src/api/openapi-spec';
import { http } from 'src/api/utils';
import { ZoomProvider } from 'src/components/zoom/zoom';
import { server } from 'src/msw-node-setup';

import { getMockedPipeline } from '../../../../mocks/mock-pipeline';
import { useWebRTCConnection, WebRTCConnectionState } from '../../../components/stream/web-rtc-connection-provider';
import { StreamContainer } from './stream-container';

vi.mock('../../../components/stream/web-rtc-connection-provider', () => ({
useWebRTCConnection: vi.fn(),
}));

describe('StreamContainer', () => {
const renderApp = ({
webRtcConfig = {},
pipelineConfig = {},
}: {
webRtcConfig?: Partial<WebRTCConnectionState>;
pipelineConfig?: Partial<SchemaPipeline>;
}) => {
vi.mocked(useWebRTCConnection).mockReturnValue({
start: vi.fn(),
status: 'idle',
stop: vi.fn(),
webRTCConnectionRef: { current: null },
...webRtcConfig,
});

server.use(
http.get('/api/projects/{project_id}/pipeline', ({ response }) =>
response(200).json(getMockedPipeline(pipelineConfig))
)
);

render(
<QueryClientProvider client={new QueryClient()}>
<ZoomProvider>
<MemoryRouter initialEntries={['/projects/123/inspect/stream']}>
<Routes>
<Route path='/projects/:projectId/inspect/stream' element={<StreamContainer />} />
</Routes>
</MemoryRouter>
</ZoomProvider>
</QueryClientProvider>
);
};

describe('Start stream button', () => {
it('call pipeline enable', async () => {
const mockedStart = vi.fn();
const pipelinePatchSpy = vi.fn();

server.use(
http.post('/api/projects/{project_id}/pipeline:enable', () => {
pipelinePatchSpy();
return HttpResponse.json({}, { status: 204 });
})
);

renderApp({ webRtcConfig: { status: 'idle', start: mockedStart }, pipelineConfig: { status: 'idle' } });

const button = await screen.findByRole('button', { name: /Start stream/i });
await userEvent.click(button);

expect(mockedStart).toHaveBeenCalled();
expect(pipelinePatchSpy).toHaveBeenCalled();
});

it('pipeline enable is enabled', async () => {
const mockedStart = vi.fn();
const pipelinePatchSpy = vi.fn();

server.use(
http.post('/api/projects/{project_id}/pipeline:enable', () => {
pipelinePatchSpy();
return HttpResponse.json({}, { status: 204 });
})
);

renderApp({ webRtcConfig: { status: 'idle', start: mockedStart }, pipelineConfig: { status: 'running' } });

const button = await screen.findByRole('button', { name: /Start stream/i });
await userEvent.click(button);

expect(mockedStart).toHaveBeenCalled();
expect(pipelinePatchSpy).not.toHaveBeenCalled();
});
});

it('render loading state', async () => {
renderApp({ webRtcConfig: { status: 'connecting' } });

expect(await screen.findByLabelText('Loading...')).toBeVisible();
});

it('webRtc connected', async () => {
renderApp({ webRtcConfig: { status: 'connected' } });

expect(await screen.findByLabelText('stream player')).toBeVisible();
});

it('autoplay webRtc if pipeline is enabled', async () => {
const mockedStart = vi.fn();
renderApp({ webRtcConfig: { status: 'idle', start: mockedStart }, pipelineConfig: { status: 'running' } });

expect(await screen.findByRole('button', { name: /Start stream/i })).toBeVisible();
expect(mockedStart).toHaveBeenCalled();
});
});
35 changes: 31 additions & 4 deletions application/ui/src/features/inspect/stream/stream-container.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,66 @@

import { useEffect, useState } from 'react';

import { useProjectIdentifier } from '@geti-inspect/hooks';
import { Button, Flex, Loading, toast, View } from '@geti/ui';
import { Play } from '@geti/ui/icons';
import { isEmpty } from 'lodash-es';
import { useEnablePipeline, usePipeline } from 'src/hooks/use-pipeline.hook';

import { useWebRTCConnection } from '../../../components/stream/web-rtc-connection-provider';
import { Stream } from './stream';

import classes from '../inference.module.scss';

export const StreamContainer = () => {
const [size, setSize] = useState({ height: 608, width: 892 });
const { projectId } = useProjectIdentifier();
const { data: pipeline } = usePipeline();
const { start, status } = useWebRTCConnection();
const enablePipeline = useEnablePipeline({ onSuccess: start });
const [size, setSize] = useState({ height: 608, width: 892 });

const hasNoSink = isEmpty(pipeline?.sink);
const hasNoSource = isEmpty(pipeline?.source);
const isPipelineRunning = pipeline?.status === 'running';

useEffect(() => {
if (status === 'failed') {
toast({ type: 'error', message: 'Failed to connect to the stream' });
}
}, [status]);

if (isPipelineRunning && status === 'idle') {
start();
}
}, [isPipelineRunning, status, start]);

const handleStart = async () => {
if (isPipelineRunning) {
start();
} else {
enablePipeline.mutate({ params: { path: { project_id: projectId } } });
}
};

return (
<View gridArea={'canvas'} overflow={'hidden'} maxHeight={'100%'}>
{status === 'idle' && (
<div className={classes.canvasContainer}>
<View backgroundColor={'gray-200'} width='90%' height='90%'>
<Flex alignItems={'center'} justifyContent={'center'} height='100%'>
<Button onPress={start} UNSAFE_className={classes.playButton} aria-label={'Start stream'}>
<Button
onPress={handleStart}
aria-label={'Start stream'}
isDisabled={hasNoSink || hasNoSource}
UNSAFE_className={classes.playButton}
>
<Play width='128px' height='128px' />
</Button>
</Flex>
</View>
</div>
)}

{status === 'connecting' && (
{(status === 'connecting' || enablePipeline.isPending) && (
<div className={classes.canvasContainer}>
<View backgroundColor={'gray-200'} width='90%' height='90%'>
<Flex alignItems={'center'} justifyContent={'center'} height='100%'>
Expand Down
45 changes: 19 additions & 26 deletions application/ui/src/features/inspect/stream/stream.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ import { Dispatch, RefObject, SetStateAction, useCallback, useEffect, useRef } f
import { useWebRTCConnection } from '../../../components/stream/web-rtc-connection-provider';
import { ZoomTransform } from '../../../components/zoom/zoom-transform';

interface StreamProps {
size: { width: number; height: number };
setSize: Dispatch<SetStateAction<{ width: number; height: number }>>;
}

const useSetTargetSizeBasedOnVideo = (
setSize: Dispatch<SetStateAction<{ width: number; height: number }>>,
videoRef: RefObject<HTMLVideoElement | null>
Expand Down Expand Up @@ -82,37 +87,25 @@ const useStreamToVideo = () => {
return videoRef;
};

export const Stream = ({
size,
setSize,
}: {
size: { width: number; height: number };
setSize: Dispatch<SetStateAction<{ width: number; height: number }>>;
}) => {
export const Stream = ({ size, setSize }: StreamProps) => {
const videoRef = useStreamToVideo();

useSetTargetSizeBasedOnVideo(setSize, videoRef);

const { status } = useWebRTCConnection();

return (
<ZoomTransform target={size}>
<div style={{ gridArea: 'innercanvas' }}>
{status === 'connected' && (
// eslint-disable-next-line jsx-a11y/media-has-caption
<video
ref={videoRef}
autoPlay
playsInline
width={size.width}
height={size.height}
controls={false}
style={{
background: 'var(--spectrum-global-color-gray-200)',
}}
/>
)}
</div>
<video
ref={videoRef}
muted
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modern browsers don’t allow videos to autoplay unless the user has interacted with the page at least once. Muting the video lets us bypass that restriction.

autoPlay
playsInline
width={size.width}
height={size.height}
controls={false}
aria-label='stream player'
style={{
background: 'var(--spectrum-global-color-gray-200)',
}}
/>
</ZoomTransform>
);
};
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
import { $api } from '@geti-inspect/api';
import { useProjectIdentifier } from '@geti-inspect/hooks';
import { Switch, toast } from '@geti/ui';
import { Flex, Switch, toast } from '@geti/ui';
import { useWebRTCConnection } from 'src/components/stream/web-rtc-connection-provider';
import { useEnablePipeline, usePipeline } from 'src/hooks/use-pipeline.hook';

import { useSelectedMediaItem } from '../../selected-media-item-provider.component';
import { WebRTCConnectionStatus } from './web-rtc-connection-status.component';

import classes from './pipeline-switch.module.scss';

export const PipelineSwitch = () => {
const { projectId } = useProjectIdentifier();
const { status, start, stop } = useWebRTCConnection();
const { onSetSelectedMediaItem } = useSelectedMediaItem();
const { data: pipeline, isLoading } = $api.useSuspenseQuery('get', '/api/projects/{project_id}/pipeline', {
params: { path: { project_id: projectId } },
});
const { data: pipeline, isLoading } = usePipeline();

const isWebRtcConnecting = status === 'connecting';

const enablePipeline = $api.useMutation('post', '/api/projects/{project_id}/pipeline:enable', {
const enablePipeline = useEnablePipeline({
onSuccess: async () => {
await start();
onSetSelectedMediaItem(undefined);
},
onError: (error) => {
if (error) {
toast({ type: 'error', message: String(error.detail) });
}
},
meta: {
invalidates: [
['get', '/api/projects/{project_id}/pipeline', { params: { path: { project_id: projectId } } }],
],
},
});

const isWebRtcConnecting = status === 'connecting';

const disablePipeline = $api.useMutation('post', '/api/projects/{project_id}/pipeline:disable', {
onSuccess: () => stop(),
onError: (error) => {
Expand All @@ -45,18 +38,25 @@ export const PipelineSwitch = () => {
},
});

const hasSink = pipeline?.sink !== undefined;
const hasSource = pipeline?.source !== undefined;

const handleChange = (isSelected: boolean) => {
const handler = isSelected ? enablePipeline.mutate : disablePipeline.mutate;
handler({ params: { path: { project_id: projectId } } });
};

return (
<Switch
onChange={handleChange}
isSelected={pipeline.status === 'running'}
isDisabled={isLoading || isWebRtcConnecting}
>
{isWebRtcConnecting ? 'Connecting...' : 'Enabled'}
</Switch>
<Flex>
<Switch
UNSAFE_className={classes.switch}
onChange={handleChange}
isSelected={pipeline.status === 'running'}
isDisabled={isLoading || isWebRtcConnecting || !hasSink || !hasSource}
>
Enabled
</Switch>
<WebRTCConnectionStatus />
</Flex>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.switch {
margin: 0px;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Flex, StatusLight } from '@geti/ui';
import { Flex, PressableElement, StatusLight, Tooltip, TooltipTrigger } from '@geti/ui';

import { useWebRTCConnection } from './web-rtc-connection-provider';
import { useWebRTCConnection } from '../../../../components/stream/web-rtc-connection-provider';

export const WebRTCConnectionStatus = () => {
const { status } = useWebRTCConnection();
Expand Down Expand Up @@ -47,9 +47,14 @@ export const WebRTCConnectionStatus = () => {
case 'connected':
return (
<Flex gap='size-200' alignItems={'center'}>
<StatusLight role={'status'} aria-label='Connected' variant='positive'>
Connected
</StatusLight>
<TooltipTrigger placement={'top'}>
<PressableElement>
<StatusLight role={'status'} aria-label='Connected' variant='info'>
Connected
</StatusLight>
</PressableElement>
<Tooltip>WebRTC is ready to use</Tooltip>
</TooltipTrigger>
</Flex>
);
}
Expand Down
Loading
Loading