# Image_Gen / app.py
# Hugging Face Space by gskdsrikrishna, commit b5ab290 ("Update app.py"), 1.99 kB.
import functools

import gradio as gr
import numpy as np
import PIL
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor
# Pretrained CLIP checkpoint whose text encoder embeds the user's prompt.
# Both objects are loaded once, at import time, and shared by every request.
model_name = "openai/clip-vit-base-patch16"
clip_model, clip_processor = (
    CLIPModel.from_pretrained(model_name),
    CLIPProcessor.from_pretrained(model_name),
)
class SimpleGenerator(nn.Module):
    """Toy generator: a single linear layer + ReLU mapping a latent to an image.

    The weights are randomly initialised and never trained, so the output is
    structured noise rather than a meaningful picture.

    Args:
        in_features: Length of the latent vector fed to ``forward``.
        image_size: Height and width of the produced image, in pixels.
        channels: Number of image channels.
    """

    def __init__(self, in_features=512, image_size=256, channels=3):
        super().__init__()
        self.image_size = image_size
        self.channels = channels
        self.fc = nn.Linear(in_features, channels * image_size * image_size)
        self.relu = nn.ReLU()

    def forward(self, z):
        """Map ``z`` of shape (batch, in_features) to (batch, channels, H, W)."""
        x = self.relu(self.fc(z))
        return x.view(-1, self.channels, self.image_size, self.image_size)


@functools.lru_cache(maxsize=1)
def _get_generator(in_features):
    """Build the (large) generator once per latent size and reuse it across calls."""
    # nn.Linear(512, 3*256*256) holds ~100M weights; rebuilding it on every
    # request is very slow and memory-hungry.
    return SimpleGenerator(in_features=in_features).eval()


def _tensor_to_pil(image_tensor):
    """Convert a (1, C, H, W) tensor with values nominally in [0, 1] to a PIL image."""
    array = image_tensor[0].permute(1, 2, 0).detach().cpu().numpy()
    array = np.clip(array, 0.0, 1.0)  # ReLU output is >= 0; cap the top end
    return Image.fromarray((array * 255).astype(np.uint8))


def generate_image_from_text(text_input):
    """Produce a 256x256 RGB image conditioned on ``text_input``.

    The prompt is embedded with the module-level ``clip_processor`` /
    ``clip_model``. The L2-normalised text embedding, rescaled so each element
    has roughly unit magnitude (the same scale as the ``torch.randn`` noise it
    replaces), is used as the generator's latent vector. The image therefore
    depends on the prompt instead of on unrelated random noise.

    Args:
        text_input: Prompt string. ``None`` is treated as an empty prompt.

    Returns:
        PIL.Image.Image: The generated image.
    """
    if text_input is None:
        text_input = ""
    # truncation=True: CLIP's text encoder has a fixed maximum sequence
    # length; over-long prompts would otherwise raise inside the model.
    inputs = clip_processor(
        text=text_input, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        text_features = clip_model.get_text_features(**inputs)
        latent = nn.functional.normalize(text_features, dim=-1)
        latent = latent * (latent.shape[-1] ** 0.5)
        generator = _get_generator(latent.shape[-1])
        image_tensor = generator(latent)
    return _tensor_to_pil(image_tensor)
# Web UI: one text box in, one image out. live=True asks Gradio to re-run
# generation automatically whenever the input changes (no submit button).
iface = gr.Interface(
    fn=generate_image_from_text,
    inputs="text",
    outputs="image",
    live=True,
)
iface.launch()