Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -7,14 +7,18 @@ import time | |
| 7 | 
             
            import gradio as gr
         | 
| 8 | 
             
            from PIL import Image
         | 
| 9 | 
             
            from io import BytesIO
         | 
|  | |
| 10 |  | 
|  | |
| 11 | 
             
            host = "http://18.119.36.46:8888"
         | 
| 12 |  | 
| 13 | 
             
            def image_prompt(prompt, image1, image2, image3, image4):
         | 
| 14 | 
             
                try:
         | 
|  | |
| 15 | 
             
                    image_sources = [open(image, "rb").read() for image in [image1, image2, image3, image4]]
         | 
| 16 | 
             
                    encoded_images = [base64.b64encode(img).decode('utf-8') for img in image_sources]
         | 
| 17 |  | 
|  | |
| 18 | 
             
                    params = {
         | 
| 19 | 
             
                        "prompt": prompt,
         | 
| 20 | 
             
                        "image_prompts": [
         | 
| @@ -28,17 +32,19 @@ def image_prompt(prompt, image1, image2, image3, image4): | |
| 28 | 
             
                        "async_process": True
         | 
| 29 | 
             
                    }
         | 
| 30 |  | 
|  | |
| 31 | 
             
                    session = requests.Session()
         | 
| 32 | 
             
                    retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
         | 
| 33 | 
             
                    session.mount('http://', HTTPAdapter(max_retries=retries))
         | 
| 34 |  | 
|  | |
| 35 | 
             
                    response = session.post(
         | 
| 36 | 
             
                        url=f"{host}/v2/generation/text-to-image-with-ip",
         | 
| 37 | 
             
                        data=json.dumps(params),
         | 
| 38 | 
             
                        headers={"Content-Type": "application/json"},
         | 
| 39 | 
            -
                        timeout=10
         | 
| 40 | 
             
                    )
         | 
| 41 | 
            -
                    response.raise_for_status()
         | 
| 42 | 
             
                    result = response.json()
         | 
| 43 | 
             
                    job_id = result.get('job_id')
         | 
| 44 |  | 
| @@ -46,13 +52,14 @@ def image_prompt(prompt, image1, image2, image3, image4): | |
| 46 | 
             
                        while True:
         | 
| 47 | 
             
                            query_url = f"{host}/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
         | 
| 48 | 
             
                            response = session.get(query_url, timeout=10)
         | 
| 49 | 
            -
                            response.raise_for_status()
         | 
| 50 | 
             
                            job_data = response.json()
         | 
| 51 |  | 
| 52 | 
             
                            job_stage = job_data.get("job_stage")
         | 
| 53 | 
             
                            job_step_preview = job_data.get("job_step_preview")
         | 
| 54 | 
             
                            job_result = job_data.get("job_result")
         | 
| 55 |  | 
|  | |
| 56 | 
             
                            if job_stage == "RUNNING" and job_step_preview:
         | 
| 57 | 
             
                                image = Image.open(BytesIO(base64.b64decode(job_step_preview)))
         | 
| 58 | 
             
                                yield image, f"Job is running smoothly. Current stage: {job_stage}. Hang tight!"
         | 
| @@ -73,7 +80,7 @@ def image_prompt(prompt, image1, image2, image3, image4): | |
| 73 | 
             
                                yield None, "Job failed. Let's check the parameters and try again."
         | 
| 74 | 
             
                                break
         | 
| 75 |  | 
| 76 | 
            -
                            time.sleep(2)
         | 
| 77 |  | 
| 78 | 
             
                    else:
         | 
| 79 | 
             
                        yield None, "Job ID not found. Did we miss something in the setup?"
         | 
|  | |
| 7 | 
             
            import gradio as gr
         | 
| 8 | 
             
            from PIL import Image
         | 
| 9 | 
             
            from io import BytesIO
         | 
| 10 | 
            +
            import os
         | 
| 11 |  | 
| 12 | 
            +
            # Your host address - this will need to scale in the future
         | 
| 13 | 
             
            host = "http://18.119.36.46:8888"
         | 
| 14 |  | 
| 15 | 
             
            def image_prompt(prompt, image1, image2, image3, image4):
         | 
| 16 | 
             
                try:
         | 
| 17 | 
            +
                    # Reading image files and encoding them
         | 
| 18 | 
             
                    image_sources = [open(image, "rb").read() for image in [image1, image2, image3, image4]]
         | 
| 19 | 
             
                    encoded_images = [base64.b64encode(img).decode('utf-8') for img in image_sources]
         | 
| 20 |  | 
| 21 | 
            +
                    # Prepare the payload with all the image prompts
         | 
| 22 | 
             
                    params = {
         | 
| 23 | 
             
                        "prompt": prompt,
         | 
| 24 | 
             
                        "image_prompts": [
         | 
|  | |
| 32 | 
             
                        "async_process": True
         | 
| 33 | 
             
                    }
         | 
| 34 |  | 
| 35 | 
            +
                    # Setup retry strategy for robust request handling
         | 
| 36 | 
             
                    session = requests.Session()
         | 
| 37 | 
             
                    retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
         | 
| 38 | 
             
                    session.mount('http://', HTTPAdapter(max_retries=retries))
         | 
| 39 |  | 
| 40 | 
            +
                    # Initiating the job
         | 
| 41 | 
             
                    response = session.post(
         | 
| 42 | 
             
                        url=f"{host}/v2/generation/text-to-image-with-ip",
         | 
| 43 | 
             
                        data=json.dumps(params),
         | 
| 44 | 
             
                        headers={"Content-Type": "application/json"},
         | 
| 45 | 
            +
                        timeout=10  # Timeout can be adjusted as needed
         | 
| 46 | 
             
                    )
         | 
| 47 | 
            +
                    response.raise_for_status()  # Ensure we catch any HTTP errors
         | 
| 48 | 
             
                    result = response.json()
         | 
| 49 | 
             
                    job_id = result.get('job_id')
         | 
| 50 |  | 
|  | |
| 52 | 
             
                        while True:
         | 
| 53 | 
             
                            query_url = f"{host}/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
         | 
| 54 | 
             
                            response = session.get(query_url, timeout=10)
         | 
| 55 | 
            +
                            response.raise_for_status()  # Catch any issues with querying the job
         | 
| 56 | 
             
                            job_data = response.json()
         | 
| 57 |  | 
| 58 | 
             
                            job_stage = job_data.get("job_stage")
         | 
| 59 | 
             
                            job_step_preview = job_data.get("job_step_preview")
         | 
| 60 | 
             
                            job_result = job_data.get("job_result")
         | 
| 61 |  | 
| 62 | 
            +
                            # Real-time update to image and status
         | 
| 63 | 
             
                            if job_stage == "RUNNING" and job_step_preview:
         | 
| 64 | 
             
                                image = Image.open(BytesIO(base64.b64decode(job_step_preview)))
         | 
| 65 | 
             
                                yield image, f"Job is running smoothly. Current stage: {job_stage}. Hang tight!"
         | 
|  | |
| 80 | 
             
                                yield None, "Job failed. Let's check the parameters and try again."
         | 
| 81 | 
             
                                break
         | 
| 82 |  | 
| 83 | 
            +
                            time.sleep(2)  # Pause for 2 seconds before checking again
         | 
| 84 |  | 
| 85 | 
             
                    else:
         | 
| 86 | 
             
                        yield None, "Job ID not found. Did we miss something in the setup?"
         | 
