gskdsrikrishna committed
Commit 1598176 · verified · 1 Parent(s): 1f9e487

Update app.py

Files changed (1)
    app.py +35 -34
app.py CHANGED
@@ -3,50 +3,51 @@ import torch
 from transformers import CLIPProcessor, CLIPModel
 from torch import nn
 import numpy as np
-import PIL
 from PIL import Image
-from torchvision import transforms
+import torchvision.transforms as transforms
 
 # Load CLIP model and processor
 model_name = "openai/clip-vit-base-patch16"
 clip_model = CLIPModel.from_pretrained(model_name)
 clip_processor = CLIPProcessor.from_pretrained(model_name)
 
-# Generate a random noise tensor (this will be transformed into an image)
+# Define a simple generator network
+class SimpleGenerator(nn.Module):
+    def __init__(self):
+        super(SimpleGenerator, self).__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(512, 1024),
+            nn.ReLU(),
+            nn.Linear(1024, 256*256*3),  # Output image pixels
+            nn.Tanh()  # Normalize output between -1 and 1
+        )
+
+    def forward(self, z):
+        x = self.fc(z)
+        x = x.view(256, 256, 3)  # Reshape to image format
+        return x
+
+# Initialize the generator model
+generator = SimpleGenerator()
+
+# Function to generate an image based on text input
 def generate_image_from_text(text_input):
-    # Preprocess the input text for CLIP model
-    inputs = clip_processor(text=text_input, return_tensors="pt", padding=True)
-
-    # Extract image-text features using CLIP
+    # Preprocess text input using CLIP
+    inputs = clip_processor(text=[text_input], return_tensors="pt", padding=True)
     text_features = clip_model.get_text_features(**inputs)
 
-    # Create a simple GAN-like generator using a random noise tensor
-    class SimpleGenerator(nn.Module):
-        def __init__(self):
-            super(SimpleGenerator, self).__init__()
-            self.fc = nn.Linear(512, 256*256*3)  # Adjust output size to match image dimensions
-            self.relu = nn.ReLU()
-
-        def forward(self, z):
-            x = self.fc(z)
-            x = self.relu(x)
-            x = x.view(-1, 3, 256, 256)  # Reshape to match image shape
-            return x
-
-    # Initialize the generator
-    generator = SimpleGenerator()
-
-    # Generate random noise based on the text features
-    random_input = torch.randn(1, 512)  # Matching CLIP output size (text_features shape)
-    generated_image_tensor = generator(random_input)
-
-    # Convert generated image tensor to PIL Image
-    generated_image = generated_image_tensor.squeeze().permute(1, 2, 0).detach().numpy()
-    generated_image = np.clip(generated_image, 0, 1)  # Normalize pixel values
-    generated_image = (generated_image * 255).astype(np.uint8)
-    generated_image = Image.fromarray(generated_image)
-
-    return generated_image
+    # Generate image tensor
+    with torch.no_grad():
+        generated_image_tensor = generator(text_features)
+
+    # Normalize tensor to (0, 255)
+    generated_image = (generated_image_tensor - generated_image_tensor.min()) / (generated_image_tensor.max() - generated_image_tensor.min())
+    generated_image = (generated_image * 255).cpu().numpy().astype(np.uint8)
+
+    # Convert to PIL Image
+    image = Image.fromarray(generated_image)
+
+    return image
 
 # Gradio interface
 iface = gr.Interface(fn=generate_image_from_text, inputs="text", outputs="image", live=True)
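
A note on the new forward pass: x.view(256, 256, 3) only reshapes cleanly because the Gradio handler passes a single prompt, so the generator receives a (1, 512) feature batch from clip_model.get_text_features. A minimal standalone shape check (a sketch, not part of the commit; fc and z here are hypothetical stand-ins for the committed layers and the CLIP features):

import torch
from torch import nn

# Mirror the committed SimpleGenerator stack: 512 -> 1024 -> 256*256*3, Tanh output
fc = nn.Sequential(
    nn.Linear(512, 1024),
    nn.ReLU(),
    nn.Linear(1024, 256 * 256 * 3),
    nn.Tanh(),
)
z = torch.randn(1, 512)        # stand-in for CLIP text features for one prompt
out = fc(z).view(256, 256, 3)  # 1 * 196608 elements == 256 * 256 * 3, so view succeeds
print(out.shape)               # torch.Size([256, 256, 3])

With a batch larger than one, the element counts no longer match and view raises a RuntimeError. The old code's x.view(-1, 3, 256, 256) kept an explicit batch dimension, while the new version trades that for an HWC layout that Image.fromarray can consume directly as a uint8 RGB array.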