harmionestark commited on
Commit
9cec643
Β·
verified Β·
1 Parent(s): 37d6c46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -17
app.py CHANGED
@@ -1,29 +1,33 @@
 
1
  from diffusers import StableDiffusionPipeline
2
  import torch
3
- from flask import Flask, request, jsonify
 
4
 
5
  app = Flask(__name__)
6
 
7
  # Load the model
8
  model_id = "kothariyashhh/GenAi-Texttoimage"
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
-
11
- # Use float32 if running on CPU
12
- torch_dtype = torch.float16 if device == "cuda" else torch.float32
13
-
14
- pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
15
- pipeline.to(device)
16
 
17
  @app.route("/generate", methods=["POST"])
18
  def generate_image():
19
- data = request.get_json()
20
- prompt = data.get("prompt", "A scenic landscape")
21
-
22
- image = pipeline(prompt).images[0]
23
- image_path = "generated_image.png"
24
- image.save(image_path)
25
-
26
- return jsonify({"image_url": image_path})
 
 
 
 
 
 
 
27
 
28
  if __name__ == "__main__":
29
- app.run(host="0.0.0.0", port=7860)
 
1
+ from flask import Flask, request, jsonify, send_file
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
+ from io import BytesIO
5
+ from PIL import Image
6
 
7
  app = Flask(__name__)
8
 
9
  # Load the model
10
  model_id = "kothariyashhh/GenAi-Texttoimage"
11
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
12
+ pipe = pipe.to("cuda") # Ensure you're using a GPU for inference
 
 
 
 
 
13
 
14
  @app.route("/generate", methods=["POST"])
15
  def generate_image():
16
+ data = request.json
17
+ prompt = data.get("prompt")
18
+ if not prompt:
19
+ return jsonify({"error": "Prompt is required"}), 400
20
+
21
+ # Generate image
22
+ with torch.autocast("cuda"):
23
+ image = pipe(prompt).images[0]
24
+
25
+ # Save image to a BytesIO object
26
+ img_io = BytesIO()
27
+ image.save(img_io, format="PNG")
28
+ img_io.seek(0)
29
+
30
+ return send_file(img_io, mimetype="image/png")
31
 
32
  if __name__ == "__main__":
33
+ app.run(host="0.0.0.0", port=5000)