# Hugging Face Spaces page header (status: Running) — scraped-page residue, kept as a comment
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from transformers import CLIPProcessor, CLIPModel
# Load the pretrained CLIP model and its processor (ViT-B/16 — 512-dim text features,
# which matches the generator's input size below)
model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(model_name)
clip_processor = CLIPProcessor.from_pretrained(model_name)
# Define a simple generator network
class SimpleGenerator(nn.Module):
    """Toy generator: maps a 512-dim embedding to a 256x256 RGB image tensor.

    NOTE(review): the weights are randomly initialized and never trained in this
    file, so the output is essentially noise shaped like an image.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256 * 256 * 3),  # one value per output pixel (HWC)
            nn.Tanh(),  # squash outputs into [-1, 1]
        )

    def forward(self, z):
        """Decode embedding(s) `z` into image tensor(s).

        Args:
            z: tensor of shape (512,) or (B, 512).

        Returns:
            (256, 256, 3) for a single embedding (1-D input or B == 1) —
            identical to the original behavior — or (B, 256, 256, 3) for a
            batch. The original `x.view(256, 256, 3)` crashed for B > 1.
        """
        x = self.fc(z)
        # Reshape per-sample, then drop the batch dim when there is one sample
        # so single-prompt callers keep getting an HWC image tensor.
        x = x.reshape(-1, 256, 256, 3)
        return x.squeeze(0) if x.shape[0] == 1 else x
# Instantiate the (untrained) generator once at module load so the Gradio
# handler reuses the same weights across calls.
generator = SimpleGenerator()
# Function to generate an image based on text input
def generate_image_from_text(text_input):
    """Encode `text_input` with CLIP and decode the embedding into a PIL image.

    Args:
        text_input: prompt string from the Gradio textbox.

    Returns:
        A 256x256 RGB ``PIL.Image`` (noise-like, since the generator is untrained).
    """
    # Inference only — wrap the CLIP encoding too, so no autograd graph is
    # built for the text features (the original tracked gradients there).
    with torch.no_grad():
        inputs = clip_processor(text=[text_input], return_tensors="pt", padding=True)
        text_features = clip_model.get_text_features(**inputs)
        generated_image_tensor = generator(text_features)

    # Min-max normalize to [0, 1]; clamp the denominator so a constant tensor
    # (max == min) cannot cause a division by zero.
    t_min = generated_image_tensor.min()
    t_max = generated_image_tensor.max()
    scale = (t_max - t_min).clamp_min(1e-8)
    generated_image = (generated_image_tensor - t_min) / scale

    # Scale to [0, 255] and convert to a uint8 HWC array for PIL.
    generated_image = (generated_image * 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(generated_image)
# Gradio interface: free-text prompt in, generated image out.
# live=True re-runs the handler as the user types.
iface = gr.Interface(
    fn=generate_image_from_text,
    inputs="text",
    outputs="image",
    live=True,
)
iface.launch()