import torch.multiprocessing as mp
import torch
import os
import re
import random
from collections import deque
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from accelerate import Accelerator
import spaces

# Check if the start method has already been set
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator
accelerator = Accelerator()

dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

# Seed words pool.
# NOTE: these module-level containers are kept so anything importing them
# still finds them, but generate_descriptions() no longer reads or mutates
# them: sharing them across requests leaked one user's subjects and
# descriptions into another user's results.
seed_words = []
used_words = set()

# Queue to store parsed descriptions (see NOTE above; no longer used).
parsed_descriptions_queue = deque()

# Usage limits (applied per request)
MAX_DESCRIPTIONS = 30  # stop generating once this many parsed descriptions exist
MAX_IMAGES = 3  # render at most this many images


def initialize_diffusers():
    """Build a FLUX.1-schnell FluxPipeline with qfloat8-quantized weights.

    The transformer and the T5 text encoder (text_encoder_2) are quantized to
    qfloat8 with optimum.quanto and frozen; the CLIP encoder, tokenizers,
    scheduler and VAE are loaded as-is. Model CPU offload is enabled on the
    returned pipeline.

    Returns:
        FluxPipeline: the ready-to-call pipeline.
    """
    # Imported lazily so these heavy modules load only when images are needed.
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(bfl_repo, subfolder='tokenizer_2', torch_dtype=dtype, revision=revision)
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    # The quantized modules are attached after construction instead of being
    # passed to the constructor. NOTE(review): presumably this sidesteps
    # constructor-time handling of quantized modules - confirm before changing.
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    pipe.enable_model_cpu_offload()

    return pipe


def generate_description_prompt(subject, user_prompt, text_generator):
    """Ask the LLM for a bracketed visual description of user_prompt.

    Args:
        subject: Text the description should differ from.
        user_prompt: What the description should depict.
        text_generator: A transformers text-generation pipeline.

    Returns:
        The generated continuation (without the prompt), stripped, or None if
        generation raised or produced only whitespace.
    """
    prompt = f"write concise vivid visual description enclosed in brackets like [ ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        outputs = text_generator(
            prompt,
            max_length=160,
            num_return_sequences=1,
            truncation=True,
            # Return only the continuation, not the echoed prompt.
            return_full_text=False,
        )
    except Exception as e:
        # Best-effort: one failed subject should not abort the whole request.
        print(f"Error generating description for subject '{subject}': {e}")
        return None
    generated_description = outputs[0]['generated_text'].strip()
    return generated_description if generated_description else None


def parse_descriptions(text):
    """Return the stripped contents of every [ ... ] span that has >= 3 words.

    Nested brackets are not supported: a span may not contain '[' or ']'.
    """
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    return [desc.strip() for desc in descriptions if len(desc.split()) >= 3]


def _parse_seed_words(seed_words_input):
    """Extract seed subjects: every "quoted" item, else comma-separated items."""
    if not seed_words_input:
        return []
    quoted = re.findall(r'"(.*?)"', seed_words_input)
    if quoted:
        return quoted
    # Fallback for unquoted input such as: cat, dog, sunset
    return [word.strip() for word in seed_words_input.split(',') if word.strip()]


@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=2):
    """Generate visual descriptions of user_prompt seeded by the given words.

    Each step picks an unused subject at random and asks the LLM for a
    description of user_prompt that differs from it. The cleaned description
    is fed back in as a new candidate subject. The [bracketed] spans it
    contains are collected.

    All state is local to this call, so concurrent or successive requests do
    not see each other's subjects or descriptions.

    Args:
        user_prompt: What the descriptions should depict.
        seed_words_input: Seed subjects, e.g. '"cat", "dog"' (quoted) or
            'cat, dog' (comma-separated fallback).
        batch_size: Unused; kept for interface compatibility.
        max_iterations: Number of successful generations to perform.

    Returns:
        list[str]: Parsed descriptions produced by this call.
    """
    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'NousResearch/Meta-Llama-3.1-8B-Instruct'
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    subjects = _parse_seed_words(seed_words_input)
    used_subjects = set()
    parsed_descriptions = []
    successes = 0

    while successes < max_iterations and len(parsed_descriptions) < MAX_DESCRIPTIONS:
        available_subjects = [word for word in subjects if word not in used_subjects]
        if not available_subjects:
            print("No more available subjects to use.")
            break

        subject = random.choice(available_subjects)
        # Mark the subject used even if generation fails, so a subject that
        # keeps failing cannot make this loop spin forever.
        used_subjects.add(subject)

        generated_description = generate_description_prompt(subject, user_prompt, text_generator)
        if not generated_description:
            continue

        clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
        print(f"Generated description for subject '{subject}': {clean_description}")
        # Chain generations: this description becomes a candidate subject.
        subjects.append(clean_description)
        parsed_descriptions.extend(parse_descriptions(clean_description))
        successes += 1

    return parsed_descriptions


@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, max_iterations=3):
    """Render one 1024x1024 image for each of the first MAX_IMAGES descriptions.

    Args:
        parsed_descriptions: Text prompts. The caller's list is NOT modified.
        max_iterations: Passed to the pipeline as num_inference_steps.

    Returns:
        list: Images collected from each pipeline call's `.images`. The list
        is empty when there are no descriptions.
    """
    prompts = list(parsed_descriptions or [])[:MAX_IMAGES]
    if not prompts:
        # Nothing to render: skip loading the expensive diffusion pipeline.
        return []

    pipe = initialize_diffusers()
    images = []
    for prompt in prompts:
        result = pipe(
            prompt,
            num_images_per_prompt=1,
            num_inference_steps=max_iterations,
            height=1024,  # Define the resolution here
            width=1024,
        )
        images.extend(result.images)
    return images


def combined_function(user_prompt, seed_words_input):
    """Generate descriptions, then images for them; return (descriptions, images)."""
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    images = generate_images(parsed_descriptions)
    return parsed_descriptions, images


if __name__ == '__main__':
    def generate_and_display(user_prompt, seed_words_input):
        """Gradio callback: return description text and the image list."""
        parsed_descriptions, images = combined_function(user_prompt, seed_words_input)
        # Join into plain text for the Textbox output rather than handing it a
        # Python list.
        return "\n\n".join(parsed_descriptions), images

    # NOTE(review): `allow_screenshot` and `clear_button` are not parameters
    # of recent gr.Interface versions (the clear button is configured via
    # `clear_btn`), and `allow_flagging` has been renamed in newer releases.
    # Confirm against the pinned gradio version.
    interface = gr.Interface(
        fn=generate_and_display,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=[gr.Textbox(label="Generated Descriptions"), gr.Gallery(label="Generated Images")],
        live=False,  # Set live to False
        allow_flagging='never',  # Disable flagging
        allow_screenshot=False,  # Disable screenshots
        clear_button=True,  # Add a clear button
    )
    interface.launch(share=True)