t3k45h1 commited on
Commit
a5f0152
Β·
verified Β·
1 Parent(s): 326e1e4

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +156 -0
main.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Form, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse, FileResponse
4
+ from diffusers import StableDiffusionPipeline
5
+ import torch
6
+ import uuid
7
+ import base64
8
+ import io
9
+ from PIL import Image
10
+ import os
11
+
12
+ # Initialize FastAPI app
13
+ app = FastAPI(title="PromptAgro Image Generator API")
14
+
15
+ # Add CORS middleware to allow frontend connections
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"], # In production, specify your frontend domains
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ # Load Stable Diffusion LCM model (your original approach)
25
+ print("πŸš€ Loading Stable Diffusion Model...")
26
+ model_id = "rupeshs/LCM-runwayml-stable-diffusion-v1-5"
27
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
28
+ pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
29
+ print("βœ… Model Loaded.")
30
+
31
+ @app.get("/")
32
+ async def root():
33
+ """Health check endpoint"""
34
+ return {
35
+ "status": "alive",
36
+ "service": "PromptAgro Image Generator",
37
+ "model_loaded": True,
38
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
39
+ }
40
+
41
+ @app.post("/generate/")
42
+ async def generate_image(prompt: str = Form(...)):
43
+ """
44
+ Generate product packaging image from input prompt.
45
+ Returns image file directly (your original approach).
46
+ """
47
+ print(f"πŸ–ŒοΈ Generating image for prompt: {prompt}")
48
+
49
+ # Generate image (your original approach)
50
+ image = pipe(prompt).images[0]
51
+
52
+ # Save image to temp file (your original approach)
53
+ filename = f"/tmp/{uuid.uuid4().hex}.png"
54
+ image.save(filename)
55
+
56
+ print(f"πŸ“¦ Image saved to {filename}")
57
+
58
+ # Return image file as response (your original approach)
59
+ return FileResponse(filename, media_type="image/png")
60
+
61
+ @app.post("/generate-json/")
62
+ async def generate_image_json(
63
+ prompt: str = Form(...),
64
+ width: int = Form(512),
65
+ height: int = Form(512),
66
+ num_inference_steps: int = Form(4), # LCM works well with few steps
67
+ guidance_scale: float = Form(1.0) # LCM uses low guidance
68
+ ):
69
+ """
70
+ Generate image and return as JSON with base64 data (for frontend integration).
71
+ """
72
+ print(f"πŸ–ŒοΈ Generating image for prompt: {prompt}")
73
+
74
+ try:
75
+ # Generate image with parameters optimized for LCM
76
+ image = pipe(
77
+ prompt=prompt,
78
+ width=width,
79
+ height=height,
80
+ num_inference_steps=num_inference_steps,
81
+ guidance_scale=guidance_scale
82
+ ).images[0]
83
+
84
+ # Convert image to base64 for JSON response
85
+ buffer = io.BytesIO()
86
+ image.save(buffer, format='PNG')
87
+ img_str = base64.b64encode(buffer.getvalue()).decode()
88
+
89
+ print("βœ… Image generated successfully")
90
+
91
+ return JSONResponse({
92
+ "success": True,
93
+ "image_data": f"data:image/png;base64,{img_str}",
94
+ "prompt_used": prompt,
95
+ "dimensions": {"width": width, "height": height},
96
+ "steps": num_inference_steps
97
+ })
98
+
99
+ except Exception as e:
100
+ print(f"❌ Generation failed: {e}")
101
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
102
+
103
+ @app.post("/generate-packaging/")
104
+ async def generate_packaging_specific(
105
+ product_name: str = Form(...),
106
+ colors: str = Form("green,yellow"),
107
+ emotion: str = Form("trust"),
108
+ platform: str = Form("farmers-market")
109
+ ):
110
+ """
111
+ Generate packaging with PromptAgro-specific prompt engineering
112
+ """
113
+ # Create professional prompt for agricultural packaging
114
+ prompt = f"""Professional agricultural product packaging design for {product_name},
115
+ modern clean style, {colors.replace(',', ' and ')} color scheme, premium typography,
116
+ conveying {emotion}, suitable for {platform}, product photography style,
117
+ white background, high quality commercial design, realistic packaging mockup,
118
+ professional studio lighting, eco-friendly agricultural branding"""
119
+
120
+ prompt = prompt.strip().replace('\n', ' ').replace(' ', ' ')
121
+
122
+ print(f"🎨 Generating packaging for: {product_name}")
123
+ print(f"πŸ“ Using prompt: {prompt}")
124
+
125
+ try:
126
+ # Generate with packaging-optimized settings
127
+ image = pipe(
128
+ prompt=prompt,
129
+ width=768,
130
+ height=768,
131
+ num_inference_steps=6,
132
+ guidance_scale=1.5
133
+ ).images[0]
134
+
135
+ # Convert to base64
136
+ buffer = io.BytesIO()
137
+ image.save(buffer, format='PNG')
138
+ img_str = base64.b64encode(buffer.getvalue()).decode()
139
+
140
+ return JSONResponse({
141
+ "success": True,
142
+ "image_data": f"data:image/png;base64,{img_str}",
143
+ "prompt_used": prompt,
144
+ "product_name": product_name,
145
+ "generator": "Stable Diffusion LCM",
146
+ "cost": "FREE",
147
+ "processing_time": "~3-5 seconds"
148
+ })
149
+
150
+ except Exception as e:
151
+ print(f"❌ Packaging generation failed: {e}")
152
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
153
+
154
+ if __name__ == "__main__":
155
+ import uvicorn
156
+ uvicorn.run(app, host="0.0.0.0", port=7860)