"""Gradio demo: encode a text prompt with CLIP and render it through a toy
fully-connected generator network.

NOTE: the generator is randomly initialized and untrained, so the output is
noise — this script demonstrates the wiring (CLIP text features -> generator
-> PIL image -> Gradio), not image synthesis quality.
"""
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from torch import nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# Load the CLIP model and its matching processor once at startup.
model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(model_name)
clip_processor = CLIPProcessor.from_pretrained(model_name)


class SimpleGenerator(nn.Module):
    """Toy MLP that maps a 512-d CLIP text embedding to a 256x256 RGB image.

    The 512 input size matches the projection dim of clip-vit-base-patch16's
    text encoder.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 256 * 3),  # one value per output pixel/channel
            nn.Tanh(),  # squash outputs into [-1, 1]
        )

    def forward(self, z):
        """Map embedding `z` to an HWC image tensor of shape (256, 256, 3).

        NOTE(review): the reshape assumes a single embedding (batch size 1);
        a larger batch would raise on the view. Confirm callers only pass one.
        """
        x = self.fc(z)
        return x.view(256, 256, 3)


# Initialize the (untrained) generator; eval() disables training-mode behavior.
generator = SimpleGenerator()
generator.eval()


def generate_image_from_text(text_input):
    """Generate a (noise) image for `text_input` and return it as a PIL Image.

    Args:
        text_input: prompt string typed into the Gradio textbox.

    Returns:
        PIL.Image.Image: 256x256 RGB image built from the generator output.
    """
    # Tokenize the prompt and encode it with CLIP's text tower. No gradients
    # are needed for inference, so run the whole pipeline under no_grad.
    inputs = clip_processor(text=[text_input], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
        generated_image_tensor = generator(text_features)

    # Min-max normalize to [0, 1]; guard against a constant tensor, which
    # would otherwise divide by zero and produce NaNs.
    t_min = generated_image_tensor.min()
    t_max = generated_image_tensor.max()
    denom = t_max - t_min
    if denom > 0:
        generated_image = (generated_image_tensor - t_min) / denom
    else:
        generated_image = torch.zeros_like(generated_image_tensor)

    # Scale to [0, 255] uint8 and convert to a PIL image for Gradio.
    generated_image = (generated_image * 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(generated_image)


# Gradio interface: a text box in, an image out; live=True re-runs on typing.
iface = gr.Interface(
    fn=generate_image_from_text,
    inputs="text",
    outputs="image",
    live=True,
)

if __name__ == "__main__":
    # Guard the server launch so importing this module has no side effects
    # beyond model loading.
    iface.launch()