RanM committed
Commit 1adc78a · verified · 1 Parent(s): 3486e1a

Update app.py

Files changed (1)
  1. app.py +10 -14
app.py CHANGED
@@ -1,4 +1,3 @@
- import json
  import gradio as gr
  import torch
  from diffusers import AutoPipelineForText2Image
@@ -6,6 +5,7 @@ import base64
  from io import BytesIO
  from generate_propmts import generate_prompt
  from concurrent.futures import ThreadPoolExecutor
+ import json

  # Load the model once outside of the function
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
@@ -15,9 +15,10 @@ def generate_image(text, sentence_mapping, character_dict, selected_style):
          prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
          image = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
          buffered = BytesIO()
-         buffered.write(image.tobytes())
          img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-         return img_str
+         if isinstance(result, img_str):
+             image_bytes = base64.b64decode(result)
+             return image_bytes
      except Exception as e:
          print(f"Error generating image: {e}")
          return None
@@ -25,19 +26,14 @@ def generate_image(text, sentence_mapping, character_dict, selected_style):

  def inference(sentence_mapping, character_dict, selected_style):
      images = {}
      print(f'sentence_mapping:{sentence_mapping}, character_dict:{character_dict}, selected_style:{selected_style}')
-
-     # Parse sentence_mapping JSON string into a dictionary
-     try:
-         grouped_sentences = json.loads(sentence_mapping)
-     except json.JSONDecodeError as e:
-         print(f"Error parsing JSON: {e}")
-         return {"error": "Invalid JSON input for sentence_mapping"}
+     # Here we assume `sentence_mapping` is a dictionary where keys are paragraph numbers and values are lists of sentences
+     grouped_sentences = sentence_mapping

      with ThreadPoolExecutor() as executor:
          futures = {}
          for paragraph_number, sentences in grouped_sentences.items():
              combined_sentence = " ".join(sentences)
-             futures[paragraph_number] = executor.submit(generate_image, combined_sentence, grouped_sentences, character_dict, selected_style)
+             futures[paragraph_number] = executor.submit(generate_image, combined_sentence, sentence_mapping, character_dict, selected_style)

          for paragraph_number, future in futures.items():
              images[paragraph_number] = future.result()
@@ -47,11 +43,11 @@ def inference(sentence_mapping, character_dict, selected_style):
  gradio_interface = gr.Interface(
      fn=inference,
      inputs=[
-         gr.Textbox(label="Sentence Mapping (JSON)"),
-         gr.Textbox(label="Character Dict (JSON)"),
+         gr.JSON(label="Sentence Mapping"),
+         gr.JSON(label="Character Dict"),
          gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style")
      ],
-     outputs="text"
+     outputs="json"
  )

  if __name__ == "__main__":
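For reference, a minimal standalone sketch of the base64 round-trip that generate_image is built around, assuming the diffusers pipeline returns a PIL image; the helper names below are illustrative and not part of the commit:

import base64
from io import BytesIO
from PIL import Image

def pil_to_base64(image: Image.Image) -> str:
    # Serialize the PIL image into an in-memory PNG, then base64-encode it.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def base64_to_bytes(img_str: str) -> bytes:
    # Decode the base64 string back into raw PNG bytes.
    return base64.b64decode(img_str)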
 
 
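Likewise, a hedged sketch of the inputs the updated interface expects: a gr.JSON input hands the handler an already-parsed Python object, so sentence_mapping arrives as a dictionary keyed by paragraph number with lists of sentences as values. The sample data and the direct call below are illustrative only:

# Hypothetical example payloads matching the new gr.JSON inputs.
sample_sentence_mapping = {
    "1": ["A knight rides through a misty forest.", "Dawn breaks over the hills."],
    "2": ["The dragon circles the ruined tower."],
}
sample_character_dict = {"knight": "a tall knight in silver armor"}

# Calling the handler directly (outside Gradio) would return a dict of results
# keyed by paragraph number:
# images = inference(sample_sentence_mapping, sample_character_dict, "Style 1")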