Shabbir-Anjum committed on
Commit
110b7d6
·
verified ·
1 Parent(s): 349bb4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -17
app.py CHANGED
@@ -1,26 +1,27 @@
1
- import torch
2
  import streamlit as st
3
- from diffusers import StableDiffusion3Pipeline
4
 
5
- # Retry mechanism for loading the model
6
- def retry_load_model():
 
7
  try:
8
- return StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
9
- except EnvironmentError as e:
 
10
  st.error(f"Error loading model: {e}")
11
- st.warning("Retrying to load the model...")
12
- retry_load_model()
13
-
14
- # Load the Diffusion pipeline
15
- pipeline = retry_load_model()
16
 
17
- def generate_prompt(prompt_text):
18
- # Generate response using the Diffusion model
19
- response = pipeline(prompt_text, top_p=0.9, max_length=100)[0]['generated_text']
20
- return response
 
 
21
 
22
  def main():
23
- st.title('Diffusion Model Prompt Generator')
 
 
 
24
 
25
  # Text input for the prompt
26
  prompt_text = st.text_area("Enter your prompt here:", height=200)
@@ -29,7 +30,7 @@ def main():
29
  if st.button("Generate"):
30
  if prompt_text:
31
  with st.spinner('Generating...'):
32
- generated_text = generate_prompt(prompt_text)
33
  st.success('Generation complete!')
34
  st.text_area('Generated Text:', value=generated_text, height=400)
35
  else:
 
 
1
  import streamlit as st
2
+ from diffusers import DiffusionPipeline
3
 
4
# Load the Diffusion pipeline once and reuse it across Streamlit reruns.
# st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True),
# which was removed in Streamlit 1.36.
@st.cache_resource
def load_diffusion_pipeline():
    """Load and cache the Stable Diffusion 3 pipeline.

    Returns:
        The loaded ``DiffusionPipeline``, or ``None`` if loading failed
        (the error is already surfaced in the UI via ``st.error``, so
        callers must check for ``None`` before invoking the pipeline).
    """
    try:
        # The "-diffusers" repo variant ships the diffusers-format weights
        # (model_index.json); the plain "stable-diffusion-3-medium" repo
        # cannot be loaded with from_pretrained.
        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers"
        )
        return pipeline
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None  # explicit: a failed load yields no pipeline
 
 
 
 
 
12
 
13
def generate_response(prompt_text, pipeline):
    """Run the pipeline on *prompt_text* and return the generated text.

    Any exception raised by the pipeline call is caught and shown in the
    UI via ``st.error``; in that case the function implicitly returns
    ``None``, which the caller receives as the generated value.

    NOTE(review): the keyword arguments ``top_p``/``max_length`` and the
    ``[0]['generated_text']`` indexing follow the transformers
    text-generation API, but ``pipeline`` here is loaded from diffusers,
    whose pipelines return image outputs (``.images``), not dicts with a
    ``'generated_text'`` key — this call presumably raises at runtime and
    is swallowed by the except block. TODO confirm against the pipeline
    actually loaded.
    """
    try:
        response = pipeline(prompt_text, top_p=0.9, max_length=100)[0]['generated_text']
        return response
    except Exception as e:
        st.error(f"Error generating response: {e}")
19
 
20
  def main():
21
+ st.title('Hugging Face Diffusion Model')
22
+
23
+ # Load the model
24
+ pipeline = load_diffusion_pipeline()
25
 
26
  # Text input for the prompt
27
  prompt_text = st.text_area("Enter your prompt here:", height=200)
 
30
  if st.button("Generate"):
31
  if prompt_text:
32
  with st.spinner('Generating...'):
33
+ generated_text = generate_response(prompt_text, pipeline)
34
  st.success('Generation complete!')
35
  st.text_area('Generated Text:', value=generated_text, height=400)
36
  else: