# Source: Hugging Face Spaces page (status: Running)
import io
import os
import tempfile
from uuid import uuid4

import boto3
import gradio as gr
import requests
from botocore.client import Config
from PIL import Image
# Stylesheet passed to gr.Blocks(css=...). It targets an element with id
# "col-container" (no component in this file sets that elem_id), centering it
# horizontally and capping its width at 512px.
css = """
#col-container {
margin: 0 auto;
max-width: 512px;
}
"""

# Maximum number of pixels (input width * height * scale**2) the remote
# upscaler is asked to produce: exactly one 1024x1024 frame.
MAX_PIXEL_BUDGET = 1024 ** 2
def upload_to_r2(file_path, object_key, content_type, expires_in=3600):
    """Upload a local file to the R2 bucket and return a presigned GET URL.

    Connection settings are read from the environment variables
    R2_ENDPOINT, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY and R2_BUCKET.

    Args:
        file_path: Path of the local file to upload (opened in binary mode).
        object_key: Key under which the object is stored in the bucket.
        content_type: MIME type recorded on the stored object.
        expires_in: Lifetime of the returned URL in seconds. Defaults to
            3600 (one hour), matching the previous hard-coded value.

    Returns:
        A presigned URL permitting an HTTP GET of the uploaded object.

    Raises:
        RuntimeError: If R2_BUCKET is not set.
    """
    # Read the bucket once so the upload and the presigned URL cannot disagree,
    # and fail early with a clear message instead of a boto3 validation error.
    bucket = os.getenv('R2_BUCKET')
    if not bucket:
        raise RuntimeError('R2_BUCKET environment variable is not set')

    s3 = boto3.client(
        's3',
        endpoint_url=os.getenv('R2_ENDPOINT'),
        aws_access_key_id=os.getenv('R2_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('R2_SECRET_ACCESS_KEY'),
        config=Config(signature_version='s3v4'),
        region_name='auto',
    )
    with open(file_path, 'rb') as f:
        s3.put_object(
            Bucket=bucket,
            Key=object_key,
            Body=f,
            ContentType=content_type,
        )
    return s3.generate_presigned_url(
        'get_object',
        Params={'Bucket': bucket, 'Key': object_key},
        ExpiresIn=expires_in,  # url expiration time in seconds
    )
def process_input(input_image, upscale_factor):
    """Adjust the scale factor and, if needed, shrink the input image.

    Two limits are applied in order:
      1. If the shorter side times the factor exceeds 1024, the factor is
         reduced to ``min(2, 1024 / shorter_side)``.
      2. If ``width * height * factor**2`` still exceeds MAX_PIXEL_BUDGET,
         the image is shrunk uniformly (aspect ratio preserved) so that the
         upscaled result fits the budget.

    Args:
        input_image: Image exposing ``.size`` and ``.resize`` (the UI passes
            a PIL image via ``gr.Image(type='pil')``).
        upscale_factor: Requested scale factor.

    Returns:
        Tuple ``(image, w_original, h_original, was_resized, upscale_factor)``
        where ``image`` is the (possibly shrunk) input, the original size is
        reported as given, ``was_resized`` says whether shrinking happened,
        and ``upscale_factor`` is the (possibly adjusted) factor.
    """
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False

    # compute minimum dimension after upscaling
    min_dimension = min(w, h) * upscale_factor
    # if minimum dimension is above 1024, adjust scale factor
    if min_dimension > 1024:
        new_scale = 1024 / min(w, h)
        # NOTE(review): when the shorter side already exceeds 1024 this yields a
        # factor below 1 (i.e. a downscale request) — confirm the API accepts it.
        upscale_factor = min(2, new_scale)  # cap at 2x if needed
        gr.Info(f'Adjusted scale factor to {upscale_factor}x to maintain minimum dimension of 1024 pixels')

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info(
            'Requested output image is too large. Resizing input to fit within pixel budget.'
        )
        # One shrink factor for both axes keeps the aspect ratio intact:
        # (w*s*f) * (h*s*f) == MAX_PIXEL_BUDGET  =>  s = sqrt(B / (w*h*f^2)).
        shrink = (MAX_PIXEL_BUDGET / (w * h * upscale_factor**2)) ** 0.5
        new_size = (max(1, int(w * shrink)), max(1, int(h * shrink)))
        input_image = input_image.resize(new_size)
        was_resized = True

    return input_image, w_original, h_original, was_resized, upscale_factor
def infer(
    input_image,
    upscale_factor,
    image_category,
):
    """Upscale an image through the remote API and return the result.

    The input goes through ``process_input`` (which may lower the scale
    factor and/or shrink the image). It is then saved as PNG and uploaded to
    R2, and the presigned URL is sent to the endpoint named by the ENDPOINT
    environment variable.

    Args:
        input_image: Image to upscale; the UI supplies a PIL image
            (``gr.Image(type='pil')``).
        upscale_factor: Requested scale factor from the slider.
        image_category: Category string forwarded unchanged to the API.

    Returns:
        The upscaled PIL image. If the input had to be shrunk to fit the
        pixel budget, the result is resized to original size * adjusted scale.

    Raises:
        gr.Error: On any failure while uploading, calling the API or decoding
            the returned image.
    """
    input_image, w_original, h_original, was_resized, adjusted_scale = process_input(
        input_image, upscale_factor
    )
    try:
        # A private temp dir per call: concurrent requests never share a
        # filename, and the file is removed even if save/upload fails.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_input_path = os.path.join(tmp_dir, 'temp_input.png')
            input_image.save(temp_input_path)
            image_url = upload_to_r2(temp_input_path, str(uuid4()), 'image/png')

        gr.Info('Upscaling image...')
        resp = requests.get(
            os.getenv('ENDPOINT'),
            headers={
                'Modal-Key': os.getenv('AUTH_KEY'),
                'Modal-Secret': os.getenv('AUTH_SECRET'),
            },
            params={
                'image_url': image_url,
                'image_category': image_category,
                'scale_factor': adjusted_scale,
                'output_format': 'png',
                'upload_to_r2': False,
            },
            # (connect, read) seconds: never hang a worker on a stalled
            # endpoint, but leave generous time for the upscale itself.
            timeout=(10, 600),
        )
        if resp.status_code != 200:
            raise gr.Error(f'API request failed with status {resp.status_code}: {resp.text}')

        # Decode straight from memory (no temp output file to clean up);
        # load() forces PIL's lazy decoder so corrupt data fails inside the try.
        output_image = Image.open(io.BytesIO(resp.content))
        output_image.load()

        if was_resized:
            target_size = (int(w_original * adjusted_scale), int(h_original * adjusted_scale))
            gr.Info(
                f'Resizing output image to targeted {target_size[0]}x{target_size[1]} size.'
            )
            output_image = output_image.resize(target_size)
        return output_image
    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f'Error during upscaling: {e}') from e
# UI layout: input/output images side by side, controls below, cached
# examples, a disclaimer, and the click handler wired to infer().
with gr.Blocks(css=css) as demo:
    # Title text: a stray "π" (a mis-decoded emoji byte) was removed.
    gr.Markdown(
        """
# Puncta Lite - AI Image Upscaler
This is a lite version of Puncta's AI image upscaler. For more advanced features and higher quality results, visit [puncta.ai](https://www.puncta.ai/).
*Note*: This demo is limited to a maximum output resolution of 1024x1024 pixels. For higher resolution upscaling, please visit our full version at puncta.ai.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_im = gr.Image(label='Input Image', type='pil')
        with gr.Column(scale=1):
            output_im = gr.Image(label='Output Image', type='pil')
    with gr.Row():
        with gr.Column(scale=1):
            upscale_factor = gr.Slider(
                label='Upscale Factor',
                minimum=2,
                maximum=4,
                step=1,
                value=2,
            )
        with gr.Column(scale=1):
            image_category = gr.Dropdown(
                label='Image Category',
                choices=['general', 'portrait', 'outdoor', 'digital art'],
                value='general',
            )
    with gr.Row():
        run_button = gr.Button(value='Upscale Image')
    # cache_examples=True runs infer() on every example when the app starts.
    examples = gr.Examples(
        examples=[
            ['examples/dogs.jpg', 2, 'outdoor'],
            ['examples/portrait1.png', 3, 'portrait'],
            ['examples/anime_1.png', 2, 'digital art'],
            ['examples/vintage_family_photo.jpg', 2, 'portrait'],
            ['examples/bus1.png', 3, 'general'],
            ['examples/festival_3.png', 2, 'general'],
            ['examples/general_2.png', 2, 'general'],
            ['examples/anime_2.jpg', 3, 'digital art'],
        ],
        inputs=[
            input_im,
            upscale_factor,
            image_category,
        ],
        fn=infer,
        outputs=output_im,
        cache_examples=True,
    )
    gr.Markdown("**Disclaimer:**")
    gr.Markdown(
        "This demo is for testing purposes only. For commercial use and higher quality results, please visit [puncta.ai](https://www.puncta.ai/)."
    )
    gr.on(
        [run_button.click],
        fn=infer,
        inputs=[
            input_im,
            upscale_factor,
            image_category,
        ],
        outputs=output_im,
        show_api=False,
    )

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.queue().launch(share=False, show_api=False)