# Author: adirik
# Commit b1dbb4f (verified): "Update app.py"
import os
import io
import requests
from uuid import uuid4
import boto3
from botocore.client import Config
from PIL import Image
import gradio as gr
# Stylesheet handed to gr.Blocks: centres the element whose id is
# "col-container" and caps its width at 512px.
# NOTE(review): no component visible in this file sets elem_id='col-container';
# confirm whether this rule still has any effect.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 512px;\n"
    "}\n"
)

# Maximum number of pixels (width * height) allowed in the upscaled output
# before the input is shrunk first: 1024 x 1024 == 2**20.
MAX_PIXEL_BUDGET = 2 ** 20
def upload_to_r2(file_path, object_key, content_type):
    """Upload a local file to the R2 (S3-compatible) bucket and return a presigned URL.

    Connection details come from the R2_ENDPOINT, R2_ACCESS_KEY_ID,
    R2_SECRET_ACCESS_KEY and R2_BUCKET environment variables.

    Args:
        file_path: path of the local file to upload.
        object_key: key under which the object is stored in the bucket.
        content_type: MIME type recorded on the stored object.

    Returns:
        A presigned GET URL for the uploaded object, valid for 3600 seconds.
    """
    bucket_name = os.getenv('R2_BUCKET')
    client = boto3.client(
        's3',
        region_name='auto',
        endpoint_url=os.getenv('R2_ENDPOINT'),
        aws_access_key_id=os.getenv('R2_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('R2_SECRET_ACCESS_KEY'),
        config=Config(signature_version='s3v4'),
    )

    with open(file_path, 'rb') as payload:
        client.put_object(
            Bucket=bucket_name,
            Key=object_key,
            Body=payload,
            ContentType=content_type,
        )

    url_lifetime_seconds = 3600
    return client.generate_presigned_url(
        'get_object',
        Params={'Bucket': bucket_name, 'Key': object_key},
        ExpiresIn=url_lifetime_seconds,
    )
def process_input(input_image, upscale_factor):
    """Adjust the scale factor and shrink the input so the output fits the pixel budget.

    Args:
        input_image: PIL image to be upscaled.
        upscale_factor: requested scale factor.

    Returns:
        A tuple ``(image, w_original, h_original, was_resized, upscale_factor)``:
        the possibly-resized input image, the original width and height, whether
        the input was resized, and the possibly-adjusted scale factor.
    """
    w_original, h_original = input_image.size
    was_resized = False

    # If upscaling would push the shorter side past 1024 px, lower the factor so
    # the shorter side lands at 1024 px, capped at 2x.
    # NOTE(review): when the shorter side already exceeds 1024 px this yields a
    # factor below 1; confirm the endpoint accepts that.
    shorter_side = min(w_original, h_original)
    if shorter_side * upscale_factor > 1024:
        upscale_factor = min(2, 1024 / shorter_side)
        gr.Info(f'Adjusted scale factor to {upscale_factor}x to maintain minimum dimension of 1024 pixels')

    # If the upscaled output would still exceed the pixel budget, shrink the input.
    # Both sides are scaled by the same factor so the aspect ratio is preserved
    # and (width * height * factor**2) lands at or just under the budget.
    if w_original * h_original * upscale_factor**2 > MAX_PIXEL_BUDGET:
        gr.Info(
            'Requested output image is too large. Resizing input to fit within pixel budget.'
        )
        shrink = (MAX_PIXEL_BUDGET / (w_original * h_original)) ** 0.5 / upscale_factor
        new_size = (
            max(1, int(w_original * shrink)),  # never request a zero-sized image
            max(1, int(h_original * shrink)),
        )
        input_image = input_image.resize(new_size)
        was_resized = True

    return input_image, w_original, h_original, was_resized, upscale_factor
def infer(
    input_image,
    upscale_factor,
    image_category,
):
    """Upscale ``input_image`` through the remote upscaling endpoint.

    The (possibly pre-shrunk) input is saved as PNG and uploaded to R2. Its
    presigned URL is then sent to the endpoint named by the ``ENDPOINT``
    environment variable, which returns the upscaled image bytes (requested
    as PNG).

    Args:
        input_image: PIL image to upscale (None if nothing was provided).
        upscale_factor: requested scale factor.
        image_category: category hint forwarded to the endpoint.

    Returns:
        The upscaled PIL image. If the input had to be shrunk to fit the pixel
        budget, the result is resized to ``original size * adjusted scale``.

    Raises:
        gr.Error: if no image was given, or if saving, uploading, the API call
            or decoding the response fails.
    """
    # Without this guard a missing image fails inside process_input with an
    # unhelpful AttributeError.
    if input_image is None:
        raise gr.Error('Please upload an input image.')

    input_image, w_original, h_original, was_resized, adjusted_scale = process_input(
        input_image, upscale_factor
    )

    # Per-request file name so concurrent requests never overwrite each other.
    # Defined before `try` so the `finally` cleanup can always reference it.
    temp_input_path = f'temp_input_{uuid4().hex}.png'
    try:
        input_image.save(temp_input_path)
        image_url = upload_to_r2(temp_input_path, str(uuid4()), 'image/png')
        gr.Info('Upscaling image...')
        # NOTE(review): no timeout is set, so a stalled endpoint blocks this
        # request indefinitely.
        resp = requests.get(
            os.getenv('ENDPOINT'),
            headers={
                'Modal-Key': os.getenv('AUTH_KEY'),
                'Modal-Secret': os.getenv('AUTH_SECRET'),
            },
            params={
                'image_url': image_url,
                'image_category': image_category,
                'scale_factor': adjusted_scale,
                'output_format': 'png',
                'upload_to_r2': False,
            },
        )
        if resp.status_code != 200:
            raise gr.Error(f'API request failed with status {resp.status_code}: {resp.text}')

        # Decode the response in memory. Calling load() forces a full decode
        # here, so a corrupt payload is reported as an upscaling error.
        output_image = Image.open(io.BytesIO(resp.content))
        output_image.load()

        if was_resized:
            target_size = (int(w_original * adjusted_scale), int(h_original * adjusted_scale))
            gr.Info(
                f'Resizing output image to targeted {target_size[0]}x{target_size[1]} size.'
            )
            output_image = output_image.resize(target_size)
        return output_image
    except gr.Error:
        # Already user-facing; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f'Error during upscaling: {str(e)}') from e
    finally:
        if os.path.exists(temp_input_path):
            os.remove(temp_input_path)
# Build the UI: an input/output image pair, the upscale controls, cached
# examples and a button, all wired to infer().
with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
# 🚀 Puncta Lite - AI Image Upscaler
This is a lite version of Puncta's AI image upscaler. For more advanced features and higher quality results, visit [puncta.ai](https://www.puncta.ai/).
*Note*: This demo is limited to a maximum output resolution of 1024x1024 pixels. For higher resolution upscaling, please visit our full version at puncta.ai.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_im = gr.Image(label='Input Image', type='pil')
        with gr.Column(scale=1):
            output_im = gr.Image(label='Output Image', type='pil')
    with gr.Row():
        with gr.Column(scale=1):
            upscale_factor = gr.Slider(
                label='Upscale Factor',
                minimum=2,
                maximum=4,
                step=1,
                value=2,
            )
        with gr.Column(scale=1):
            image_category = gr.Dropdown(
                label='Image Category',
                choices=['general', 'portrait', 'outdoor', 'digital art'],
                value='general',
            )
    with gr.Row():
        run_button = gr.Button(value='Upscale Image')

    # The examples and the button feed infer() the same inputs, in the order
    # of its parameters; keep a single list so the two cannot drift apart.
    inference_inputs = [input_im, upscale_factor, image_category]

    examples = gr.Examples(
        examples=[
            ['examples/dogs.jpg', 2, 'outdoor'],
            ['examples/portrait1.png', 3, 'portrait'],
            ['examples/anime_1.png', 2, 'digital art'],
            ['examples/vintage_family_photo.jpg', 2, 'portrait'],
            ['examples/bus1.png', 3, 'general'],
            ['examples/festival_3.png', 2, 'general'],
            ['examples/general_2.png', 2, 'general'],
            ['examples/anime_2.jpg', 3, 'digital art'],
        ],
        inputs=inference_inputs,
        fn=infer,
        outputs=output_im,
        cache_examples=True,
    )
    gr.Markdown("**Disclaimer:**")
    gr.Markdown(
        "This demo is for testing purposes only. For commercial use and higher quality results, please visit [puncta.ai](https://www.puncta.ai/)."
    )
    gr.on(
        [run_button.click],
        fn=infer,
        inputs=inference_inputs,
        outputs=output_im,
        show_api=False,
    )

# Only start the server when run as a script, so importing the module has no
# side effect beyond building `demo`.
if __name__ == '__main__':
    demo.queue().launch(share=False, show_api=False)