pravin0077 commited on
Commit
7c1115e
·
verified ·
1 Parent(s): 4b04990

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -101
app.py CHANGED
@@ -1,115 +1,93 @@
1
- # Import necessary libraries
 
 
 
2
  import requests
3
  import io
4
  from PIL import Image
5
- import matplotlib.pyplot as plt
6
- from transformers import MarianMTModel, MarianTokenizer, pipeline
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import gradio as gr
9
- import os # For accessing environment variables
10
-
11
- # Constants for model names and API URLs
12
- class Constants:
13
- TRANSLATION_MODEL_NAME = "Helsinki-NLP/opus-mt-mul-en"
14
- IMAGE_GENERATION_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
15
- GPT_NEO_MODEL_NAME = "EleutherAI/gpt-neo-125M"
16
- # Get the Hugging Face API token from environment variables
17
- HEADERS = {"Authorization": f"Bearer {os.getenv('HUGGINGFACE_API_KEY')}"}
18
-
19
- # Translation Class
20
- class Translator:
21
- def __init__(self):
22
- self.tokenizer = MarianTokenizer.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
23
- self.model = MarianMTModel.from_pretrained(Constants.TRANSLATION_MODEL_NAME)
24
- self.pipeline = pipeline("translation", model=self.model, tokenizer=self.tokenizer)
25
-
26
- def translate(self, tamil_text):
27
- """Translate Tamil text to English."""
28
- try:
29
- translation = self.pipeline(tamil_text, max_length=40)
30
- return translation[0]['translation_text']
31
- except Exception as e:
32
- return f"Translation error: {str(e)}"
33
 
34
-
35
- # Image Generation Class
36
- class ImageGenerator:
37
- def __init__(self):
38
- self.api_url = Constants.IMAGE_GENERATION_API_URL
39
-
40
- def generate(self, prompt):
41
- """Generate an image based on the given prompt."""
42
- try:
43
- response = requests.post(self.api_url, headers=Constants.HEADERS, json={"inputs": prompt})
44
- if response.status_code == 200:
45
- image_bytes = response.content
46
- return Image.open(io.BytesIO(image_bytes))
47
- else:
48
- print(f"Image generation failed: Status code {response.status_code}")
49
- return None
50
- except Exception as e:
51
- print(f"Image generation error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return None
53
-
54
-
55
- # Creative Text Generation Class
56
- class CreativeTextGenerator:
57
- def __init__(self):
58
- self.tokenizer = AutoTokenizer.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
59
- self.model = AutoModelForCausalLM.from_pretrained(Constants.GPT_NEO_MODEL_NAME)
60
-
61
- def generate(self, translated_text):
62
- """Generate creative text based on translated text."""
63
- input_ids = self.tokenizer(translated_text, return_tensors='pt').input_ids
64
- generated_text_ids = self.model.generate(input_ids, max_length=100)
65
- return self.tokenizer.decode(generated_text_ids[0], skip_special_tokens=True)
66
-
67
-
68
- # Main Application Class
69
- class TransArtApp:
70
- def __init__(self):
71
- self.translator = Translator()
72
- self.image_generator = ImageGenerator()
73
- self.creative_text_generator = CreativeTextGenerator()
74
-
75
- def process(self, tamil_text):
76
- """Handle the full workflow: translate, generate image, and creative text."""
77
- translated_text = self.translator.translate(tamil_text)
78
- image = self.image_generator.generate(translated_text)
79
- creative_text = self.creative_text_generator.generate(translated_text)
80
- return translated_text, creative_text, image
81
-
82
-
83
- # Function to display images
84
- def show_image(image):
85
- """Display an image using matplotlib."""
86
- if image:
87
- plt.imshow(image)
88
- plt.axis('off') # Hide axes
89
- plt.show()
90
- else:
91
- print("No image to display.")
92
-
93
-
94
- # Create an instance of the TransArt app
95
- app = TransArtApp()
96
-
97
- # Gradio interface function
98
- def gradio_interface(tamil_text):
99
- """Interface function for Gradio."""
100
- translated_text, creative_text, image = app.process(tamil_text)
101
  return translated_text, creative_text, image
102
 
 
 
 
 
103
 
104
- # Create Gradio interface
105
  interface = gr.Interface(
106
- fn=gradio_interface,
107
- inputs="text",
 
 
 
 
108
  outputs=["text", "text", "image"],
109
- title="Tamil to English Translation, Image Generation & Creative Text",
110
- description="Enter Tamil text to translate to English, generate an image, and create creative text based on the translation."
111
  )
112
 
113
  # Launch Gradio app
114
- if __name__ == "__main__":
115
- interface.launch()
 
1
+ import os
2
+ import concurrent.futures
3
+ from huggingface_hub import login
4
+ from transformers import MarianMTModel, MarianTokenizer, pipeline
5
  import requests
6
  import io
7
  from PIL import Image
 
 
 
8
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Login with Hugging Face token
11
+ hf_token = os.getenv("HF_TOKEN")
12
+ if hf_token:
13
+ login(token=hf_token, add_to_git_credential=True)
14
+ else:
15
+ raise ValueError("Hugging Face token not found in environment variables.")
16
+
17
+ # Dynamic translation model loading
18
+ def load_translation_model(src_lang, tgt_lang):
19
+ model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
20
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
21
+ model = MarianMTModel.from_pretrained(model_name)
22
+ translator = pipeline("translation", model=model, tokenizer=tokenizer)
23
+ return translator
24
+
25
+ # Translation function with reduced max_length
26
+ def translate_text(text, src_lang, tgt_lang):
27
+ try:
28
+ translator = load_translation_model(src_lang, tgt_lang)
29
+ translation = translator(text, max_length=20) # Reduced max length for speed
30
+ return translation[0]['translation_text']
31
+ except Exception as e:
32
+ return f"An error occurred: {str(e)}"
33
+
34
+ # Image generation with reduced resolution
35
+ flux_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
36
+ flux_headers = {"Authorization": f"Bearer {hf_token}"}
37
+ def generate_image(prompt):
38
+ try:
39
+ response = requests.post(flux_API_URL, headers=flux_headers, json={"inputs": prompt})
40
+ if response.status_code == 200:
41
+ image = Image.open(io.BytesIO(response.content))
42
+ image = image.resize((256, 256)) # Reduce resolution for faster processing
43
+ return image
44
+ else:
45
  return None
46
+ except Exception as e:
47
+ print(f"Error in image generation: {e}")
48
+ return None
49
+
50
+ # Creative text generation with reduced length
51
+ mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
52
+ mistral_headers = {"Authorization": f"Bearer {hf_token}"}
53
+ def generate_creative_text(translated_text):
54
+ try:
55
+ response = requests.post(mistral_API_URL, headers=mistral_headers, json={"inputs": translated_text, "max_length": 30})
56
+ if response.status_code == 200:
57
+ return response.json()[0]['generated_text']
58
+ else:
59
+ return "Error generating creative text"
60
+ except Exception as e:
61
+ print(f"Error in creative text generation: {e}")
62
+ return None
63
+
64
+ # Full workflow function with parallel processing
65
+ def translate_generate_image_and_text(text, src_lang, tgt_lang):
66
+ translated_text = translate_text(text, src_lang, tgt_lang)
67
+ with concurrent.futures.ThreadPoolExecutor() as executor:
68
+ image_future = executor.submit(generate_image, translated_text)
69
+ creative_text_future = executor.submit(generate_creative_text, translated_text)
70
+ image = image_future.result()
71
+ creative_text = creative_text_future.result()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return translated_text, creative_text, image
73
 
74
+ # Language options for Gradio dropdown
75
+ language_codes = {
76
+ "Tamil": "ta", "English": "en", "French": "fr", "Spanish": "es", "German": "de"
77
+ }
78
 
79
+ # Gradio Interface
80
  interface = gr.Interface(
81
+ fn=translate_generate_image_and_text,
82
+ inputs=[
83
+ gr.Textbox(label="Enter text"),
84
+ gr.Dropdown(choices=list(language_codes.keys()), label="Source Language", default="Tamil"),
85
+ gr.Dropdown(choices=list(language_codes.keys()), label="Target Language", default="English"),
86
+ ],
87
  outputs=["text", "text", "image"],
88
+ title="Multilingual Translation, Image Generation & Creative Text",
89
+ description="Translate text between languages, generate images based on translation, and create creative text.",
90
  )
91
 
92
  # Launch Gradio app
93
+ interface.launch()