# Image_Gen / app.py
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from torch import nn
import numpy as np
from PIL import Image

# Load CLIP model and processor
model_name = "openai/clip-vit-base-patch16"
clip_model = CLIPModel.from_pretrained(model_name)
clip_processor = CLIPProcessor.from_pretrained(model_name)
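# Note: the text projection of this CLIP checkpoint is 512-dimensional, which is why the
# generator below expects a 512-dim input. Optional sanity check (illustrative, safe to remove):
assert clip_model.config.projection_dim == 512, "Generator input size must match CLIP's text projection dim"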
# Define a simple generator network that maps a CLIP text embedding to an RGB image
class SimpleGenerator(nn.Module):
    def __init__(self):
        super(SimpleGenerator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),            # 512-dim CLIP text features in
            nn.ReLU(),
            nn.Linear(1024, 256 * 256 * 3),  # Output image pixels
            nn.Tanh()                        # Normalize output between -1 and 1
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(256, 256, 3)  # Reshape to image format (assumes a single input embedding)
        return x
# Initialize the generator model
generator = SimpleGenerator()
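# Illustrative sanity check (an assumption about the intended shapes, not part of the original app):
# a single random 512-dim embedding should map to a 256x256 RGB tensor. Safe to remove.
with torch.no_grad():
    _demo_output = generator(torch.randn(1, 512))
assert _demo_output.shape == (256, 256, 3), "Unexpected generator output shape"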
# Function to generate an image based on text input
def generate_image_from_text(text_input):
    # Preprocess the text and encode it with CLIP; no gradients are needed at inference time
    inputs = clip_processor(text=[text_input], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
        # Generate the image tensor from the text embedding
        generated_image_tensor = generator(text_features)
    # Min-max normalize the tensor to [0, 1], then scale to (0, 255)
    generated_image = (generated_image_tensor - generated_image_tensor.min()) / (
        generated_image_tensor.max() - generated_image_tensor.min()
    )
    generated_image = (generated_image * 255).cpu().numpy().astype(np.uint8)
    # Convert to a PIL Image
    image = Image.fromarray(generated_image)
    return image
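# Example usage (illustrative; the generator is untrained, so the output will look like noise):
#   img = generate_image_from_text("a watercolor painting of a fox")
#   img.save("output.png")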
# Gradio interface; live=True re-runs generation on every text change
iface = gr.Interface(fn=generate_image_from_text, inputs="text", outputs="image", live=True)
iface.launch()
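# When running locally, iface.launch(share=True) is a standard Gradio option for creating a
# temporary public link (not needed on Hugging Face Spaces, which hosts the app directly).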