RanM committed
Commit e0ec116 · verified · 1 Parent(s): 1789a95

Update app.py

Files changed (1)
  1. app.py +23 -22
app.py CHANGED
@@ -3,27 +3,19 @@ import torch
  from diffusers import AutoPipelineForText2Image
  from io import BytesIO
  from generate_propmts import generate_prompt
- from concurrent.futures import ThreadPoolExecutor
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  import json

  # Load the model once outside of the function
  model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")

- # Helper function to truncate prompt to fit the model's maximum sequence length
- # def truncate_prompt(prompt, max_length=77):
- #     return prompt[:max_length]

- def generate_image(text, sentence_mapping, character_dict, selected_style):
+ def generate_image(prompt):
      try:
-         prompt = generate_prompt(text, sentence_mapping, character_dict, selected_style)
-         print(f"Generated prompt: {prompt}")
          # Truncate prompt if necessary
-         # prompt = truncate_prompt(prompt)
-         # print(f"truncate_prompt: {prompt}")
-         output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
+         output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0]
          print(f"Model output: {output}")
-         print("len of output:", len(output))
-         print("output.images:", output.images)
+
          # Check if the model returned images
          if output.images:
              image = output.images[0]
@@ -41,18 +33,27 @@ def generate_image(text, sentence_mapping, character_dict, selected_style):
  def inference(sentence_mapping, character_dict, selected_style):
      images = {}
      print(f'sentence_mapping: {sentence_mapping}, character_dict: {character_dict}, selected_style: {selected_style}')
-     # Here we assume `sentence_mapping` is a dictionary where keys are paragraph numbers and values are lists of sentences
-     grouped_sentences = sentence_mapping
+     prompts = []
+
+     # Generate prompts for each paragraph
+     for paragraph_number, sentences in sentence_mapping.items():
+         combined_sentence = " ".join(sentences)
+         prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
+         prompts.append((paragraph_number, prompt))
+         print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

      with ThreadPoolExecutor() as executor:
-         futures = {}
-         for paragraph_number, sentences in grouped_sentences.items():
-             combined_sentence = " ".join(sentences)
-             futures[paragraph_number] = executor.submit(generate_image, combined_sentence, sentence_mapping, character_dict, selected_style)
+         future_to_paragraph = {executor.submit(generate_image, prompt): paragraph_number for paragraph_number, prompt in prompts}
+
+         for future in as_completed(future_to_paragraph):
+             paragraph_number = future_to_paragraph[future]
+             try:
+                 image = future.result()
+                 if image:
+                     images[paragraph_number] = image
+             except Exception as e:
+                 print(f"Error processing paragraph {paragraph_number}: {e}")

-         for paragraph_number, future in futures.items():
-             images[paragraph_number] = future.result()
-     print(f'images:{images}')
      return images

  gradio_interface = gr.Interface(
@@ -60,7 +61,7 @@ gradio_interface = gr.Interface(
      inputs=[
          gr.JSON(label="Sentence Mapping"),
          gr.JSON(label="Character Dict"),
-         gr.Dropdown(["Style 1", "Style 2", "Style 3"], label="Selected Style")
+         gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
      ],
      outputs="json"
  )
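
A note on the reworked `generate_image`: the pipeline call now ends in `.images[0]`, so `output` is already a single PIL image, while the retained `if output.images:` check expects the full pipeline output and would raise AttributeError at runtime. Below is a minimal sketch, not part of this commit, of keeping the pipeline output intact so the existing check still holds; the JPEG-encoding return is an assumption about the unshown remainder of the function (BytesIO is already imported in app.py).

# Sketch only: assumes `model` is the AutoPipelineForText2Image loaded above.
def generate_image(prompt):
    try:
        # Keep the full pipeline output so `output.images` remains valid
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        if output.images:
            image = output.images[0]
            buffer = BytesIO()
            image.save(buffer, format="JPEG")  # assumption: the rest of app.py returns encoded bytes
            return buffer.getvalue()
        print("No image returned by the model")
        return None
    except Exception as e:
        print(f"Error generating image: {e}")
        return None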
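
The new `inference` uses the standard future-to-key pattern: `as_completed` yields futures in completion order, not submission order, and the `future_to_paragraph` lookup maps each result back to its paragraph number. A tiny standalone illustration of that pattern (the `square` worker is hypothetical, nothing from app.py):

from concurrent.futures import ThreadPoolExecutor, as_completed

def square(x):  # hypothetical stand-in for generate_image
    return x * x

results = {}
with ThreadPoolExecutor() as executor:
    future_to_key = {executor.submit(square, n): n for n in [3, 1, 2]}
    for future in as_completed(future_to_key):  # completion order, not submission order
        key = future_to_key[future]
        results[key] = future.result()

print(results)  # e.g. {1: 1, 2: 4, 3: 9}; keys tie results back regardless of completion order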