jbilcke-hf HF staff commited on
Commit
cb7d06b
·
1 Parent(s): 4872066

adding captionning step

Browse files
src/index.mts CHANGED
@@ -182,6 +182,7 @@ app.post("/render", async (req, res) => {
182
  renderId: "",
183
  status: "pending",
184
  assetUrl: "",
 
185
  maskUrl: "",
186
  error: "",
187
  segments: []
@@ -246,6 +247,7 @@ app.get("/render/:renderId", async (req, res) => {
246
  renderId: "",
247
  status: "pending",
248
  assetUrl: "",
 
249
  error: "",
250
  maskUrl: "",
251
  segments: []
 
182
  renderId: "",
183
  status: "pending",
184
  assetUrl: "",
185
+ alt: request.prompt || "",
186
  maskUrl: "",
187
  error: "",
188
  segments: []
 
247
  renderId: "",
248
  status: "pending",
249
  assetUrl: "",
250
+ alt: "",
251
  error: "",
252
  maskUrl: "",
253
  segments: []
src/production/renderImageAnalysis.mts ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { analyzeImage } from "../analysis/analyzeImageWithIDEFICSAndNastyHack.mts"
2
+ import { RenderedScene, RenderRequest } from "../types.mts"
3
+ import { upscaleImage } from "../utils/upscaleImage.mts"
4
+
5
+ export async function renderImageAnalysis(
6
+ request: RenderRequest,
7
+ response: RenderedScene,
8
+ ): Promise<RenderedScene> {
9
+ response.alt = request.prompt
10
+
11
+ try {
12
+ // note: this converts a base64 PNG to a base64 JPG (which is good, actually!)
13
+ response.assetUrl = await analyzeImage(response.assetUrl, response.assetUrl)
14
+ console.log(`analysis worked on the first try!`)
15
+ } catch (err) {
16
+ console.error(`analysis failed the first time.. let's try again..`)
17
+ try {
18
+ response.assetUrl = await upscaleImage(response.assetUrl, request.upscalingFactor)
19
+ console.log(`analysis worked on the second try!`)
20
+ } catch (err) {
21
+ console.error(`analysis failed on the second attempt.. let's keep the prompt as a fallback, then :|`)
22
+ // no need to log a catastrophic failure here, since we still have the original (low-res image)
23
+ // to work with
24
+ response.alt = request.prompt
25
+ }
26
+ }
27
+
28
+ return response
29
+ }
src/production/renderPipeline.mts CHANGED
@@ -5,9 +5,9 @@ import { renderImage } from "./renderImage.mts"
5
  import { renderVideo } from "./renderVideo.mts"
6
  import { renderImageSegmentation } from "./renderImageSegmentation.mts"
7
  import { renderVideoSegmentation } from "./renderVideoSegmentation.mts"
8
- import { upscaleImage } from "../utils/upscaleImage.mts"
9
  import { renderImageUpscaling } from "./renderImageUpscaling.mts"
10
  import { saveRenderedSceneToCache } from "../utils/saveRenderedSceneToCache.mts"
 
11
 
12
  export async function renderPipeline(request: RenderRequest, response: RenderedScene) {
13
  const isVideo = request?.nbFrames > 1
@@ -40,8 +40,13 @@ export async function renderPipeline(request: RenderRequest, response: RenderedS
40
  ? Promise.resolve()
41
  : renderImageUpscaling(request, response)
42
 
 
 
 
 
43
  await Promise.all([
44
  renderSegmentation(request, response),
 
45
  optionalUpscalingStep
46
  ])
47
 
 
5
  import { renderVideo } from "./renderVideo.mts"
6
  import { renderImageSegmentation } from "./renderImageSegmentation.mts"
7
  import { renderVideoSegmentation } from "./renderVideoSegmentation.mts"
 
8
  import { renderImageUpscaling } from "./renderImageUpscaling.mts"
9
  import { saveRenderedSceneToCache } from "../utils/saveRenderedSceneToCache.mts"
10
+ import { renderImageAnalysis } from "./renderImageAnalysis.mts"
11
 
12
  export async function renderPipeline(request: RenderRequest, response: RenderedScene) {
13
  const isVideo = request?.nbFrames > 1
 
40
  ? Promise.resolve()
41
  : renderImageUpscaling(request, response)
42
 
43
+ const optionalAnalysisStep = request.analyze
44
+ ? renderImageAnalysis(request, response)
45
+ : Promise.resolve()
46
+
47
  await Promise.all([
48
  renderSegmentation(request, response),
49
+ optionalAnalysisStep,
50
  optionalUpscalingStep
51
  ])
52
 
src/production/renderScene.mts CHANGED
@@ -1,10 +1,7 @@
1
  import { v4 as uuidv4 } from "uuid"
2
 
3
  import { RenderedScene, RenderRequest } from "../types.mts"
4
- import { generateSeed } from "../utils/generateSeed.mts"
5
- import { getValidNumber } from "../utils/getValidNumber.mts"
6
  import { renderPipeline } from "./renderPipeline.mts"
7
- import { getValidBoolean } from "../utils/getValidBoolean.mts"
8
 
9
  const cache: Record<string, RenderedScene> = {}
10
  const cacheQueue: string[] = []
@@ -19,6 +16,7 @@ export async function renderScene(request: RenderRequest): Promise<RenderedScene
19
  renderId,
20
  status: "pending",
21
  assetUrl: "",
 
22
  error: "",
23
  maskUrl: "",
24
  segments: []
 
1
  import { v4 as uuidv4 } from "uuid"
2
 
3
  import { RenderedScene, RenderRequest } from "../types.mts"
 
 
4
  import { renderPipeline } from "./renderPipeline.mts"
 
5
 
6
  const cache: Record<string, RenderedScene> = {}
7
  const cacheQueue: string[] = []
 
16
  renderId,
17
  status: "pending",
18
  assetUrl: "",
19
+ alt: request.prompt || "",
20
  error: "",
21
  maskUrl: "",
22
  segments: []
src/types.mts CHANGED
@@ -316,6 +316,8 @@ export interface RenderRequest {
316
  cache: CacheMode
317
 
318
  wait: boolean // wait until the job is completed
 
 
319
  }
320
 
321
  export interface ImageAnalysisRequest {
@@ -358,6 +360,7 @@ export interface RenderedScene {
358
  renderId: string
359
  status: RenderedSceneStatus
360
  assetUrl: string
 
361
  error: string
362
  maskUrl: string
363
  segments: ImageSegment[]
 
316
  cache: CacheMode
317
 
318
  wait: boolean // wait until the job is completed
319
+
320
+ analyze: boolean // analyze the image to generate a caption (optional)
321
  }
322
 
323
  export interface ImageAnalysisRequest {
 
360
  renderId: string
361
  status: RenderedSceneStatus
362
  assetUrl: string
363
+ alt: string
364
  error: string
365
  maskUrl: string
366
  segments: ImageSegment[]
src/utils/parseRenderRequest.mts CHANGED
@@ -24,6 +24,8 @@ export function parseRenderRequest(request: RenderRequest) {
24
 
25
  request.nbSteps = getValidNumber(request.nbSteps, 5, 50, 10)
26
 
 
 
27
  if (isVideo) {
28
  request.width = getValidNumber(request.width, 256, 1024, 1024)
29
  request.height = getValidNumber(request.height, 256, 1024, 512)
 
24
 
25
  request.nbSteps = getValidNumber(request.nbSteps, 5, 50, 10)
26
 
27
+ request.analyze = request?.analyze ? true : false
28
+
29
  if (isVideo) {
30
  request.width = getValidNumber(request.width, 256, 1024, 1024)
31
  request.height = getValidNumber(request.height, 256, 1024, 512)