"""TRANSART: translate Tamil text to English, continue it with GPT-2, and
generate an image from the translation through the Hugging Face Inference API.
"""

import io
import os

import gradio as gr
import requests
from PIL import Image
from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer

# Load the translation model
model_name = "Helsinki-NLP/opus-mt-mul-en"
translation_model = MarianMTModel.from_pretrained(model_name)
translation_tokenizer = MarianTokenizer.from_pretrained(model_name)

# Load GPT-2 model and tokenizer (smaller and faster than GPT-Neo)
gpt_model_name = "gpt2"
gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)
gpt_model = AutoModelForCausalLM.from_pretrained(gpt_model_name)

# Upper bound (seconds) on the image-generation HTTP call so a stalled
# request cannot block the UI forever.
_IMAGE_REQUEST_TIMEOUT = 120


def translate_text(tamil_text):
    """Translate ``tamil_text`` to English with the loaded Marian model.

    Args:
        tamil_text: Source text to translate.

    Returns:
        The decoded translation of the first generated sequence, with special
        tokens stripped.
    """
    # truncation=True keeps over-long inputs within the tokenizer's maximum
    # length instead of failing inside the model.
    inputs = translation_tokenizer(tamil_text, return_tensors="pt", truncation=True)
    translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)


def query_gpt_2(translated_text):
    """Generate a short continuation of ``translated_text`` with GPT-2.

    Args:
        translated_text: English text used to build the generation prompt.

    Returns:
        The decoded model output. The decoded sequence contains the prompt
        followed by the newly generated tokens.
    """
    prompt = f"Continue the story based on the following text: {translated_text}"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    outputs = gpt_model.generate(
        **inputs,  # passes input_ids together with the attention mask
        # Budget the *generated* tokens only: max_length would also count the
        # prompt and leave no room for output once the prompt is long.
        max_new_tokens=50,
        num_return_sequences=1,
        # GPT-2 defines no pad token; reuse EOS so generate() has one.
        pad_token_id=gpt_tokenizer.eos_token_id,
    )
    return gpt_tokenizer.decode(outputs[0], skip_special_tokens=True)


def query_image(payload):
    """Request an image from the Hugging Face Inference API (FLUX.1-dev).

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "a prompt"}``.

    Returns:
        The raw response body as ``bytes`` when the API answers with HTTP 200.
        Otherwise a ``str`` starting with ``"Error:"`` describing the failure
        (missing API key, transport error, or non-200 status).
    """
    huggingface_api_key = os.getenv('HUGGINGFACE_API_KEY')
    if not huggingface_api_key:
        return "Error: Hugging Face API key not set."

    API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=_IMAGE_REQUEST_TIMEOUT
        )
    except requests.RequestException as exc:
        # Report transport failures (timeouts, connection errors) through the
        # same error-string convention as the other failure paths.
        return f"Error: {exc}"

    if response.status_code == 200:
        return response.content
    return f"Error: {response.status_code} - {response.text}"


def process_input(tamil_input):
    """Run the translate -> continue -> illustrate pipeline for the UI.

    Args:
        tamil_input: Text entered by the user.

    Returns:
        A ``(translated_text, creative_text, image)`` tuple on success. On any
        failure the tuple is ``("Error occurred: <reason>", "", None)``.
    """
    if not tamil_input or not tamil_input.strip():
        # Nothing to translate; skip the model and API calls entirely.
        return "Error occurred: input text is empty.", "", None

    try:
        # Translate the input text
        translated_output = translate_text(tamil_input)

        # Generate creative text using GPT-2
        creative_output = query_gpt_2(translated_output)

        # Generate an image using Hugging Face's FLUX model
        image_result = query_image({"inputs": translated_output})
        if isinstance(image_result, str):
            # query_image signals failure with an error string; surface that
            # message instead of letting BytesIO raise an unrelated TypeError.
            raise RuntimeError(image_result)
        image = Image.open(io.BytesIO(image_result))

        return translated_output, creative_output, image
    except Exception as e:
        # UI boundary: report the failure in the first output box rather than
        # crashing the Gradio callback.
        return f"Error occurred: {str(e)}", "", None


# Create a Gradio interface
interface = gr.Interface(
    fn=process_input,
    inputs=[gr.Textbox(label="Input Tamil Text")],
    outputs=[
        gr.Textbox(label="Translated Text"),
        gr.Textbox(label="Creative Text"),
        gr.Image(label="Generated Image"),
    ],
    title="TRANSART",
    description="Enter Tamil text to translate to English and generate an image based on the translated text.",
)

interface.launch()