# transart / app.py
# Source: Hugging Face Hub file page — author pravin0077, commit 57e55c5
# ("Update app.py", verified), 3.81 kB.
import concurrent.futures
import functools
import io
import os

import gradio as gr
import requests
from huggingface_hub import login
from PIL import Image
from transformers import MarianMTModel, MarianTokenizer, pipeline
# Authenticate with the Hugging Face Hub. The token comes from the
# HUGGINGFACE_API_KEY environment variable; startup aborts if it is missing
# or empty, so `hf_token` is always a non-empty string after this point.
hf_token = os.environ.get("HUGGINGFACE_API_KEY")
if not hf_token:
    raise ValueError("Hugging Face token not found in environment variables.")
login(token=hf_token, add_to_git_credential=True)
# Dynamic translation model loading, memoised per language pair
@functools.lru_cache(maxsize=4)
def load_translation_model(src_lang, tgt_lang):
    """Return a translation pipeline for ``Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}``.

    Building the tokenizer, model and pipeline is expensive and this is
    called once per translation request, so results are cached per
    ``(src_lang, tgt_lang)`` pair. At most 4 pipelines are kept alive to
    bound memory. A failed load raises and is not cached, so it is retried
    on the next call.

    Args:
        src_lang: Source language code as it appears in the model id (e.g. "ta").
        tgt_lang: Target language code (e.g. "en").

    Returns:
        A ``transformers`` translation pipeline.
    """
    model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    tokenizer = MarianTokenizer.from_pretrained(model_name)
    model = MarianMTModel.from_pretrained(model_name)
    return pipeline("translation", model=model, tokenizer=tokenizer)
# Translation function with a configurable (default: reduced) max_length
def translate_text(text, src_lang, tgt_lang, max_length=20):
    """Translate *text* from *src_lang* to *tgt_lang* with an opus-mt model.

    Args:
        text: Text to translate.
        src_lang: Source language code as used in the opus-mt model id (e.g. "ta").
        tgt_lang: Target language code (e.g. "en").
        max_length: Generation length cap passed to the translation
            pipeline. The default of 20 favours speed but can cut off longer
            translations; pass a larger value for full sentences.

    Returns:
        The translated string; ``""`` for empty/whitespace-only input; or a
        string starting with ``"An error occurred: "`` if loading or running
        the model raised. Errors are returned, not raised.
    """
    # Nothing to translate: avoid loading a model at all.
    if not text or not text.strip():
        return ""
    try:
        translator = load_translation_model(src_lang, tgt_lang)
        translation = translator(text, max_length=max_length)
        return translation[0]["translation_text"]
    except Exception as e:
        # Reported as a string so the Gradio UI can display it.
        return f"An error occurred: {str(e)}"
# Image generation with reduced resolution
flux_API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
flux_headers = {"Authorization": f"Bearer {hf_token}"}


def generate_image(prompt, timeout=120):
    """Generate an image for *prompt* via the hosted FLUX.1-dev Inference API.

    Args:
        prompt: Text prompt, sent as ``{"inputs": prompt}``.
        timeout: Seconds to wait for the HTTP request. ``requests`` waits
            indefinitely when no timeout is given, which would hang the
            calling worker thread.

    Returns:
        A 256x256 ``PIL.Image`` on success, otherwise ``None`` (network
        error, timeout, non-200 status, or an undecodable body). Failures
        are printed to stdout.
    """
    try:
        response = requests.post(
            flux_API_URL,
            headers=flux_headers,
            json={"inputs": prompt},
            timeout=timeout,
        )
        if response.status_code != 200:
            # Surface the failure instead of returning None silently.
            print(
                f"Image generation failed: HTTP {response.status_code}: "
                f"{response.text[:200]}"
            )
            return None
        image = Image.open(io.BytesIO(response.content))
        # Downscale for faster processing; aspect ratio is not preserved.
        return image.resize((256, 256))
    except Exception as e:
        # Best-effort: the UI shows an empty image rather than crashing.
        print(f"Error in image generation: {e}")
        return None
# Creative text generation with reduced length
mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
mistral_headers = {"Authorization": f"Bearer {hf_token}"}


def generate_creative_text(translated_text, max_new_tokens=30, timeout=60):
    """Continue *translated_text* with Mistral-7B via the Inference API.

    Args:
        translated_text: Prompt text.
        max_new_tokens: Cap on generated tokens. The Inference API only
            honours generation options inside the ``"parameters"`` object; a
            top-level ``"max_length"`` key is ignored.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        The ``generated_text`` field of the first result, or the string
        ``"Error generating creative text"`` on any failure (non-200 status,
        network error, unexpected response shape).
    """
    payload = {
        "inputs": translated_text,
        "parameters": {"max_new_tokens": max_new_tokens},
    }
    try:
        response = requests.post(
            mistral_API_URL, headers=mistral_headers, json=payload, timeout=timeout
        )
        if response.status_code == 200:
            return response.json()[0]["generated_text"]
        print(f"Creative text generation failed: HTTP {response.status_code}")
        return "Error generating creative text"
    except Exception as e:
        print(f"Error in creative text generation: {e}")
        # Same failure value as the non-200 path so callers see one contract.
        return "Error generating creative text"
# Full workflow function with parallel processing
def translate_generate_image_and_text(text, src_lang, tgt_lang):
    """Translate *text*, then generate an image and creative text from the result.

    Args:
        text: Input text in the source language.
        src_lang: Source language. Either a display name from
            ``language_codes`` (what the Gradio dropdowns send, e.g. "Tamil")
            or a raw code (e.g. "ta").
        tgt_lang: Target language, in the same forms as *src_lang*.

    Returns:
        ``(translated_text, creative_text, image)``. For empty input, or
        when translation fails, the first element is ``""`` or the error
        message and the two generation steps are skipped (``None``).
    """
    if not text or not text.strip():
        return "", None, None

    # The dropdowns deliver display names, but opus-mt model ids need codes
    # ("Tamil" -> "ta"). Unknown values pass through so code callers still work.
    src_code = language_codes.get(src_lang, src_lang)
    tgt_code = language_codes.get(tgt_lang, tgt_lang)

    if src_code == tgt_code:
        # Translating a language into itself is a no-op; skip the model load.
        translated_text = text
    else:
        translated_text = translate_text(text, src_code, tgt_code)
        # translate_text reports failure by returning this prefix rather than
        # raising; don't send the error message to the generators as a prompt.
        if translated_text.startswith("An error occurred:"):
            return translated_text, None, None

    # The two remote calls are independent and I/O-bound, so overlap them.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        image_future = executor.submit(generate_image, translated_text)
        creative_text_future = executor.submit(generate_creative_text, translated_text)
        image = image_future.result()
        creative_text = creative_text_future.result()
    return translated_text, creative_text, image
# Dropdown display names mapped to two-letter language codes (the form used
# in the Helsinki-NLP/opus-mt-{src}-{tgt} model ids). Insertion order is the
# order shown in the Gradio dropdowns.
language_codes = dict(
    Tamil="ta",
    English="en",
    French="fr",
    Spanish="es",
    German="de",
)
def _language_dropdown(label, default):
    """Build a dropdown offering every display name in ``language_codes``."""
    return gr.Dropdown(label=label, value=default, choices=list(language_codes.keys()))


# Gradio Interface: text plus source/target language names in;
# translation, creative text and generated image out.
interface = gr.Interface(
    fn=translate_generate_image_and_text,
    title="Multilingual Translation, Image Generation & Creative Text",
    description="Translate text between languages, generate images based on translation, and create creative text.",
    inputs=[
        gr.Textbox(label="Enter text"),
        _language_dropdown("Source Language", "Tamil"),
        _language_dropdown("Target Language", "English"),
    ],
    outputs=["text", "text", "image"],
)

# Launch Gradio app (starts the server as soon as this module executes)
interface.launch()