codewithdark committed on
Commit 13a93cb · verified · 1 Parent(s): ca17086

Update utility/image_generator.py

Files changed (1)
  1. utility/image_generator.py +30 -19
utility/image_generator.py CHANGED
@@ -20,25 +20,36 @@ def generate_image_prompts(script):
 
     return prompts
 
-def hf_pipeline(prompt):
-    API_URL = "https://api-inference.huggingface.co/models/Shakker-Labs/AWPortrait-FL"
-    headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"}
-
-    response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
-
-    if response.status_code == 200:
-        return Image.open(io.BytesIO(response.content))  # Return the image directly
-    else:
-        raise Exception(f"Failed to generate image. Status code: {response.status_code}, {response.text}")
+def hf_pipeline(prompt, max_retries=5, delay=30):
+    retries = 0
+    while retries < max_retries:
+        response = requests.post(f"https://api-inference.huggingface.co/models/Shakker-Labs/AWPortrait-FL",
+                                 json={"inputs": prompt})
+        if response.status_code == 503:
+            print(f"Model is loading, retrying in {delay} seconds...")
+            retries += 1
+            time.sleep(delay)
+        elif response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"Failed to generate image. Status code: {response.status_code}, {response.text}")
+
+    raise Exception(f"Failed to generate image after {max_retries} retries.")
 
 def generate_images(prompts):
-    image_files = []
-    for idx, prompt in enumerate(prompts):
-        print(f"Generating image for prompt: {prompt}")
-        # Ensure the prompt is processed on the correct device
-        image = hf_pipeline(prompt).images[0]
-        filename = f"generated_image_{idx}.png"
-        image.save(filename)
-        image_files.append(filename)
+    try:
+        image_files = []
+        for idx, prompt in enumerate(prompts):
+            print(f"Generating image for prompt: {prompt}")
+            # Ensure the prompt is processed on the correct device
+            image = hf_pipeline(prompt).images[0]
+            filename = f"generated_image_{idx}.png"
+            image.save(filename)
+            image_files.append(filename)
 
-    return image_files
+        return image_files
+    except Exception as e:
+        print(e)
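Note that the removed hf_pipeline decoded response.content with Image.open(io.BytesIO(...)), which suggests this Inference API endpoint returns raw image bytes, while the new version returns response.json() and generate_images still calls .images[0] on that result, so the two functions may no longer line up. Below is a minimal sketch of a helper that keeps the retry-on-503 behaviour but hands back a PIL Image; the name fetch_image_with_retry and the re-added Authorization header are assumptions for illustration, not part of this commit.

import io
import os
import time

import requests
from PIL import Image

API_URL = "https://api-inference.huggingface.co/models/Shakker-Labs/AWPortrait-FL"


def fetch_image_with_retry(prompt, max_retries=5, delay=30):
    """Hypothetical helper: retry while the model is loading (HTTP 503) and
    decode the raw image bytes into a PIL Image, as the removed hf_pipeline did."""
    headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"}
    for _ in range(max_retries):
        response = requests.post(API_URL, headers=headers, json={"inputs": prompt})
        if response.status_code == 503:
            # Model is still loading on the Hub; wait and try again.
            print(f"Model is loading, retrying in {delay} seconds...")
            time.sleep(delay)
            continue
        if response.status_code == 200:
            # Decode the returned bytes directly, matching the old behaviour.
            return Image.open(io.BytesIO(response.content))
        raise Exception(
            f"Failed to generate image. Status code: {response.status_code}, {response.text}"
        )
    raise Exception(f"Failed to generate image after {max_retries} retries.")

With a helper along these lines, generate_images could assign image = fetch_image_with_retry(prompt) directly instead of calling hf_pipeline(prompt).images[0].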