gokilashree committed
Commit 98aa4c2 · verified · 1 Parent(s): 3bb4e50

Update app.py

Files changed (1): app.py +7 -3
app.py CHANGED

@@ -7,7 +7,7 @@ from PIL import Image
 import os

 # Set up the Hugging Face API key from environment variables
-hf_api_key = os.getenv("new_hf_token")
+hf_api_key = os.getenv("HF_API_KEY")
 if not hf_api_key:
     raise ValueError("Hugging Face API key not found! Please set the 'HF_API_KEY' environment variable.")
 headers = {"Authorization": f"Bearer {hf_api_key}"}
@@ -20,10 +20,14 @@ translation_model_name = "facebook/mbart-large-50-many-to-one-mmt"
 tokenizer = MBart50Tokenizer.from_pretrained(translation_model_name)
 translation_model = MBartForConditionalGeneration.from_pretrained(translation_model_name)

-# Load a text generation model from Hugging Face
+# Load a text generation model from Hugging Face using accelerate for memory optimization
 text_generation_model_name = "EleutherAI/gpt-neo-2.7B" # You can switch to "EleutherAI/gpt-j-6B" if available
 text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
-text_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name, device_map="auto", torch_dtype=torch.float32)
+text_model = AutoModelForCausalLM.from_pretrained(
+    text_generation_model_name,
+    device_map="auto",  # Automatically allocate model layers to devices (requires accelerate)
+    torch_dtype=torch.float32  # Specify dtype to optimize memory usage
+)

 # Create a pipeline for text generation
 text_generator = pipeline("text-generation", model=text_model, tokenizer=text_tokenizer)
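
For context, a minimal standalone sketch of the pattern this commit adopts: read the key from the renamed HF_API_KEY variable, then load the causal LM with device_map="auto" so accelerate places layers on the available devices. This is an illustration, not the full app.py; the prompt text and max_new_tokens value are arbitrary, and it assumes transformers, torch, and accelerate are installed.

# Sketch only: the loading pattern app.py now uses (assumes transformers,
# torch, and accelerate are installed; prompt and length are arbitrary).
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Fail fast if the renamed environment variable is missing.
hf_api_key = os.getenv("HF_API_KEY")
if not hf_api_key:
    raise ValueError("Hugging Face API key not found! Please set the 'HF_API_KEY' environment variable.")

model_name = "EleutherAI/gpt-neo-2.7B"
text_tokenizer = AutoTokenizer.from_pretrained(model_name)
text_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",          # accelerate assigns layers to available devices
    torch_dtype=torch.float32,  # explicit dtype keeps memory use predictable
)

# When device_map is used, transformers records the chosen layer placement here.
print(text_model.hf_device_map)

text_generator = pipeline("text-generation", model=text_model, tokenizer=text_tokenizer)
print(text_generator("Once upon a time", max_new_tokens=20)[0]["generated_text"])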