Dddixyy committed
Commit 3fa5eb5 · verified · 1 Parent(s): 8575513

Update app.py

Files changed (1)
  1. app.py +19 -3
app.py CHANGED
@@ -1,17 +1,33 @@
 import gradio as gr
 import torch
 from transformers import MarianMTModel, MarianTokenizer
+from optimum.intel import INCQuantizer
 
-# Load the model and tokenizer from the Hub
+# Load and optimize the model (quantization)
 model_name = "Dddixyy/latin-italian-translator"
+
+# Load a quantized model if available, otherwise fall back to the regular model (quantization shown as an example)
+try:
+    # Attempt to load a quantized version if it is available
+    quantizer = INCQuantizer.from_pretrained(model_name)
+    model = quantizer.quantize()
+    print("Quantized model loaded.")
+except Exception as e:
+    print(f"Error loading quantized model: {e}")
+    model = MarianMTModel.from_pretrained(model_name)
+
+# Load the tokenizer
 tokenizer = MarianTokenizer.from_pretrained(model_name)
-model = MarianMTModel.from_pretrained(model_name)
 
 # Translation function
 def translate_latin_to_italian(latin_text):
-    inputs = tokenizer(latin_text, return_tensors="pt", padding=True, truncation=True)
+    # Truncate input to 512 tokens to avoid overload (adjust as necessary)
+    inputs = tokenizer(latin_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
     with torch.no_grad():
         generated_ids = model.generate(inputs["input_ids"])
+
+    # Decode the generated ids into a readable translation
     translation = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
     return translation[0]
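
A note on the quantization step in this diff: the INCQuantizer call as committed likely does not match the optimum-intel API (INCQuantizer.from_pretrained expects an already-loaded model rather than a Hub id, and quantize() expects a quantization config and save directory), so at runtime the except branch will usually fall back to the full-precision model, as the commit's own comment anticipates. Below is a minimal sketch of the same idea using stock PyTorch dynamic quantization instead, assuming CPU inference; the Latin sample sentence is purely illustrative and not part of the commit.

import torch
from transformers import MarianMTModel, MarianTokenizer

model_name = "Dddixyy/latin-italian-translator"
tokenizer = MarianTokenizer.from_pretrained(model_name)

# Load the full-precision model, then convert all nn.Linear layers to int8.
# torch.quantization.quantize_dynamic quantizes weights ahead of time and
# activations on the fly, shrinking the model and speeding up CPU inference.
model = MarianMTModel.from_pretrained(model_name)
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
model.eval()

# Same translation path as app.py, including the 512-token truncation.
inputs = tokenizer(
    "agricola terram arat",  # illustrative input, not from the commit
    return_tensors="pt", padding=True, truncation=True, max_length=512,
)
with torch.no_grad():
    generated_ids = model.generate(inputs["input_ids"])
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

Because dynamic quantization happens at load time, this variant needs no calibration data, no saved quantized checkpoint, and no extra dependency beyond PyTorch itself.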