gokilashree committed on
Commit
f8a580e
·
verified ·
1 Parent(s): e55a3d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer, AutoModelForCausalLM, AutoTokenizer, pipeline
2
  import gradio as gr
3
  import requests
@@ -20,7 +21,7 @@ tokenizer = MBart50Tokenizer.from_pretrained(translation_model_name)
20
  translation_model = MBartForConditionalGeneration.from_pretrained(translation_model_name)
21
 
22
  # Load a text generation model from Hugging Face
23
- text_generation_model_name = "EleutherAI/gpt-neo-2.7B" # Use "EleutherAI/gpt-j-6B" for better quality
24
  text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
25
  text_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name, device_map="auto", torch_dtype=torch.float32)
26
 
 
1
+ import torch # Explicitly import torch to avoid import issues
2
  from transformers import MBartForConditionalGeneration, MBart50Tokenizer, AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import gradio as gr
4
  import requests
 
21
  translation_model = MBartForConditionalGeneration.from_pretrained(translation_model_name)
22
 
23
  # Load a text generation model from Hugging Face
24
+ text_generation_model_name = "EleutherAI/gpt-neo-2.7B" # You can switch to "EleutherAI/gpt-j-6B" if available
25
  text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
26
  text_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name, device_map="auto", torch_dtype=torch.float32)
27