"""Gradio app: expand quoted seed words into visual descriptions with an LLM,
then render them with a quantized FLUX.1-schnell pipeline.

Flow: ``combined_function`` -> ``generate_descriptions`` (Llama 3.1 text
generation) -> ``initialize_diffusers`` + ``generate_images`` (Flux).
"""

# `@spaces.GPU` below needs the Hugging Face `spaces` package. It is imported
# first because ZeroGPU expects to be loaded before CUDA is initialised.
# Outside a Space, fall back to a decorator that leaves functions unchanged.
try:
    import spaces
except ImportError:
    class _NoOpSpaces:
        """Stand-in for the `spaces` module whose `GPU` decorator is a no-op."""

        @staticmethod
        def GPU(func=None, **_kwargs):
            # Supports both `@spaces.GPU` and `@spaces.GPU(duration=...)`.
            if func is None:
                return lambda f: f
            return func

    spaces = _NoOpSpaces()

import os
import random
import re
from collections import deque

import gradio as gr
import torch
import torch.multiprocessing as mp
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Select the 'spawn' start method unless it is already active (PyTorch's
# multiprocessing notes: CUDA cannot be re-initialised in forked children).
if mp.get_start_method(allow_none=True) != 'spawn':
    mp.set_start_method('spawn')

# Instantiate the Accelerator
accelerator = Accelerator()
dtype = torch.bfloat16

# Set environment variables for local path
os.environ['FLUX_DEV'] = '.'
os.environ['AE'] = '.'

# Module-level state retained for backward compatibility only.
# generate_descriptions uses per-call state so one request's subjects and
# results cannot leak into another request.
seed_words = []
used_words = set()
parsed_descriptions_queue = deque()

# Usage limits
MAX_DESCRIPTIONS = 30
MAX_IMAGES = 12


def initialize_diffusers():
    """Build a FLUX.1-schnell ``FluxPipeline`` with qfloat8-quantized weights.

    The Flux transformer and the T5 text encoder are quantized with
    optimum-quanto and frozen, attached to the pipeline after construction,
    and model CPU offload is enabled.

    Returns:
        The configured ``FluxPipeline``.
    """
    # Imported lazily, presumably to keep module import light.
    from optimum.quanto import freeze, qfloat8, quantize
    from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
    from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
    from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
    from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

    bfl_repo = 'black-forest-labs/FLUX.1-schnell'
    revision = 'refs/pr/1'

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        bfl_repo, subfolder='scheduler', revision=revision)
    text_encoder = CLIPTextModel.from_pretrained(
        'openai/clip-vit-large-patch14', torch_dtype=dtype)
    # Tokenizers hold no tensors, so they take no torch_dtype.
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    text_encoder_2 = T5EncoderModel.from_pretrained(
        bfl_repo, subfolder='text_encoder_2', torch_dtype=dtype, revision=revision)
    tokenizer_2 = T5TokenizerFast.from_pretrained(
        bfl_repo, subfolder='tokenizer_2', revision=revision)
    vae = AutoencoderKL.from_pretrained(
        bfl_repo, subfolder='vae', torch_dtype=dtype, revision=revision)
    transformer = FluxTransformer2DModel.from_pretrained(
        bfl_repo, subfolder='transformer', torch_dtype=dtype, revision=revision)

    quantize(transformer, weights=qfloat8)
    freeze(transformer)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # The quantized modules are attached after construction instead of being
    # passed to the constructor.
    # NOTE(review): this mirrors the common quanto + Flux recipe; confirm it
    # is still required with the installed diffusers version.
    pipe = FluxPipeline(
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=None,
        tokenizer_2=tokenizer_2,
        vae=vae,
        transformer=None,
    )
    pipe.text_encoder_2 = text_encoder_2
    pipe.transformer = transformer
    pipe.enable_model_cpu_offload()
    return pipe


def generate_description_prompt(subject, user_prompt, text_generator):
    """Ask the LLM for a bracketed visual description of ``user_prompt``.

    Args:
        subject: Text the description is asked to differ from.
        user_prompt: What the description should depict.
        text_generator: A transformers ``text-generation`` pipeline.

    Returns:
        The generated continuation (prompt excluded), stripped, or ``None``
        if it is empty or generation raised.
    """
    prompt = f"write concise vivid visual description enclosed in brackets like [ ] less than 100 words of {user_prompt} different from {subject}. "
    try:
        # return_full_text=False makes the pipeline return only new text, so
        # the prompt does not have to be stripped back out afterwards.
        outputs = text_generator(
            prompt,
            max_length=160,
            num_return_sequences=1,
            truncation=True,
            return_full_text=False,
        )
        generated_description = outputs[0]['generated_text'].strip()
    except Exception as e:
        # Best effort: one failed subject should not abort the whole batch.
        print(f"Error generating description for subject '{subject}': {e}")
        return None
    return generated_description or None


def parse_descriptions(text):
    """Return the ``[...]`` segments of ``text`` that contain at least 3 words.

    Only innermost bracket pairs are captured (contents may not contain
    ``[`` or ``]``); each captured segment is stripped of outer whitespace.
    """
    descriptions = re.findall(r'\[([^\[\]]+)\]', text)
    return [desc.strip() for desc in descriptions if len(desc.split()) >= 3]


@spaces.GPU
def generate_descriptions(user_prompt, seed_words_input, batch_size=100, max_iterations=50):
    """Iteratively generate descriptions, chaining each output as a new subject.

    Args:
        user_prompt: What every description should depict.
        seed_words_input: Free text; each double-quoted substring
            (e.g. ``"cat", "dog"``) becomes an initial subject.
        batch_size: Unused; kept for interface compatibility.
        max_iterations: Upper bound on generation attempts, failed ones
            included, so the loop always terminates.

    Returns:
        At most ``MAX_DESCRIPTIONS`` parsed description strings.
    """
    # Per-call state, so concurrent or successive requests stay independent.
    subjects = re.findall(r'"(.*?)"', seed_words_input)
    if not subjects:
        # Avoid loading an 8B model when there is nothing to describe.
        print("No more available subjects to use.")
        return []
    used_subjects = set()
    parsed = []

    print("Initializing the text generation pipeline with 16-bit precision...")
    model_name = 'meta-llama/Meta-Llama-3.1-8B-Instruct'
    # TODO(review): the model is reloaded on every request; consider caching.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
    print("Text generation pipeline initialized with 16-bit precision.")

    for iteration_count in range(max_iterations):
        if len(parsed) >= MAX_DESCRIPTIONS:
            break
        available_subjects = [word for word in subjects if word not in used_subjects]
        if not available_subjects:
            print("No more available subjects to use.")
            break
        subject = random.choice(available_subjects)
        generated_description = generate_description_prompt(subject, user_prompt, text_generator)
        if not generated_description:
            # Still counts toward max_iterations, which bounds retries.
            continue
        clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
        print(f"Generated description for subject '{subject}': {clean_description}")
        used_subjects.add(subject)
        # Feed the output back as a future subject to diversify later prompts.
        subjects.append(clean_description)
        # Only every third iteration's output is parsed into results
        # (presumably deliberate sampling — kept from the original).
        if iteration_count % 3 == 0:
            parsed.extend(parse_descriptions(clean_description))

    return parsed[:MAX_DESCRIPTIONS]


@spaces.GPU(duration=120)
def generate_images(parsed_descriptions, pipe):
    """Render one image per description, for at most ``MAX_IMAGES`` of them.

    Args:
        parsed_descriptions: Prompt strings; the list is not modified.
        pipe: A ``FluxPipeline`` such as the one from ``initialize_diffusers``.

    Returns:
        The generated images, concatenated from each call's ``.images``.
    """
    images = []
    for prompt in parsed_descriptions[:MAX_IMAGES]:
        # FluxPipeline's keyword is `num_images_per_prompt`; it has no
        # `num_images` parameter.
        images.extend(pipe(prompt, num_images_per_prompt=1).images)
    return images


def combined_function(user_prompt, seed_words_input):
    """Gradio handler: generate descriptions, then render images for them."""
    parsed_descriptions = generate_descriptions(user_prompt, seed_words_input)
    if not parsed_descriptions:
        # Nothing to render; skip building the Flux pipeline.
        return []
    pipe = initialize_diffusers()
    return generate_images(parsed_descriptions, pipe)


if __name__ == '__main__':
    # torch.cuda.init() raises when CUDA is unavailable, so guard it.
    if torch.cuda.is_available():
        torch.cuda.init()
    interface = gr.Interface(
        fn=combined_function,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter a prompt for descriptions..."),
            gr.Textbox(lines=2, placeholder='Enter seed words in quotes, e.g., "cat", "dog", "sunset"...'),
        ],
        outputs=gr.Gallery(),
    )
    interface.launch()