deep-div commited on
Commit
1041179
Β·
verified Β·
1 Parent(s): fc9ce6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -47
app.py CHANGED
@@ -1,21 +1,18 @@
1
  import streamlit as st
2
- import os
3
- import time
4
  import torch
5
  import logging
 
6
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
7
 
8
- # Set Streamlit page configuration
9
- st.set_page_config(page_title="M2M100 Translator")
10
 
11
- # Check device
12
- if torch.cuda.is_available():
13
- device = torch.device("cuda:0")
14
- else:
15
- device = torch.device("cpu")
16
- logging.warning("GPU not found, using CPU, translation will be very slow.")
17
 
18
- # Language code mapping
19
  lang_id = {
20
  "Afrikaans": "af", "Amharic": "am", "Arabic": "ar", "Asturian": "ast",
21
  "Azerbaijani": "az", "Bashkir": "ba", "Belarusian": "be", "Bulgarian": "bg",
@@ -44,58 +41,58 @@ lang_id = {
44
  "Yiddish": "yi", "Yoruba": "yo", "Chinese": "zh", "Zulu": "zu",
45
  }
46
 
47
- # Cache the model and tokenizer using new API
48
  @st.cache_resource
49
- def load_model(pretrained_model="facebook/m2m100_1.2B", cache_dir="models/"):
50
- tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
51
- model = M2M100ForConditionalGeneration.from_pretrained(
52
- pretrained_model, cache_dir=cache_dir
53
- ).to(device)
54
  model.eval()
55
  return tokenizer, model
56
 
57
- # App Title and Intro
58
- st.title("🌐 M2M100 Translator")
59
- st.write("""
60
- M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation.
61
- It supports **100 languages** and translates in **9900 directions**.
62
- Model: `facebook/m2m100_1.2B`
63
- More info: [Paper](https://arxiv.org/abs/2010.11125) | [Repo](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100)
64
- """)
65
 
66
- # Input Text Area
67
  user_input = st.text_area(
68
- "Enter text to translate:",
69
  height=200,
70
  max_chars=5120,
71
- placeholder="Type your sentence here..."
72
  )
73
 
74
- # Language selectors
75
- source_lang = st.selectbox("Select source language", sorted(lang_id.keys()))
76
- target_lang = st.selectbox("Select target language", sorted(lang_id.keys()))
 
 
 
77
 
78
  # Translate Button
79
- if st.button("Translate"):
80
  with st.spinner("Translating... Please wait"):
81
- time_start = time.time()
82
  tokenizer, model = load_model()
83
 
84
- src_lang = lang_id[source_lang]
85
- trg_lang = lang_id[target_lang]
86
 
87
- tokenizer.src_lang = src_lang
88
  with torch.no_grad():
89
- encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
90
- generated_tokens = model.generate(
91
- **encoded_input,
92
- forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
93
  )
94
- translated_text = tokenizer.batch_decode(
95
- generated_tokens, skip_special_tokens=True
96
- )[0]
 
 
 
 
97
 
98
- time_end = time.time()
99
- st.success("Translation complete!")
100
- st.markdown(f"**Translated Text:**\n\n{translated_text}")
101
- st.caption(f"Time taken: {round(time_end - time_start, 2)} seconds")
 
1
  import streamlit as st
 
 
2
  import torch
3
  import logging
4
+ import time
5
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
6
 
7
+ # Configure page
8
+ st.set_page_config(page_title="🌐 Translator", page_icon="🌐")
9
 
10
+ # Device detection
11
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12
+ if device.type == "cpu":
13
+ logging.warning("⚠️ GPU not found β€” using CPU (translation may be slow).")
 
 
14
 
15
+ # Language mapping
16
  lang_id = {
17
  "Afrikaans": "af", "Amharic": "am", "Arabic": "ar", "Asturian": "ast",
18
  "Azerbaijani": "az", "Bashkir": "ba", "Belarusian": "be", "Bulgarian": "bg",
 
41
  "Yiddish": "yi", "Yoruba": "yo", "Chinese": "zh", "Zulu": "zu",
42
  }
43
 
44
+ # Cache model/tokenizer loading
45
  @st.cache_resource
46
+ def load_model():
47
+ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
48
+ model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B").to(device)
 
 
49
  model.eval()
50
  return tokenizer, model
51
 
52
+ # Title
53
+ st.title("🌍 M2M100 Language Translator")
54
+ st.markdown("πŸ” Translate text between **100+ languages** using Facebook's `M2M100` multilingual model.")
 
 
 
 
 
55
 
56
+ # Text input
57
  user_input = st.text_area(
58
+ "✏️ Enter your text below:",
59
  height=200,
60
  max_chars=5120,
61
+ placeholder="E.g. Hello, how are you?"
62
  )
63
 
64
+ # Language selections (default: English β†’ Hindi)
65
+ col1, col2 = st.columns(2)
66
+ with col1:
67
+ source_lang = st.selectbox("🌐 Source Language", sorted(lang_id.keys()), index=list(lang_id.keys()).index("English"))
68
+ with col2:
69
+ target_lang = st.selectbox("πŸ” Target Language", sorted(lang_id.keys()), index=list(lang_id.keys()).index("Hindi"))
70
 
71
  # Translate Button
72
+ if st.button("πŸš€ Translate", disabled=(not user_input.strip())):
73
  with st.spinner("Translating... Please wait"):
74
+ start = time.time()
75
  tokenizer, model = load_model()
76
 
77
+ src = lang_id[source_lang]
78
+ tgt = lang_id[target_lang]
79
 
80
+ tokenizer.src_lang = src
81
  with torch.no_grad():
82
+ encoded = tokenizer(user_input, return_tensors="pt").to(device)
83
+ output = model.generate(
84
+ **encoded,
85
+ forced_bos_token_id=tokenizer.get_lang_id(tgt)
86
  )
87
+ result = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
88
+
89
+ end = time.time()
90
+ st.success("βœ… Translation complete!")
91
+ st.markdown("### πŸ“ Translated Text")
92
+ st.text_area("Output", value=result, height=150, disabled=True)
93
+ st.caption(f"⏱️ Time taken: {round(end - start, 2)} seconds")
94
 
95
+ # Optional reset
96
+ st.markdown("---")
97
+ if st.button("πŸ”„ Reset"):
98
+ st.experimental_rerun()