pyresearch committed · Commit 9d27c6c · verified · 1 Parent(s): 4a40fb3

Update app.py

Files changed (1): app.py +13 -16
app.py CHANGED
@@ -2,32 +2,29 @@ import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-# Load model and tokenizer
-tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+# Use GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Move model to CPU
-model.to("cpu")
+st.title("Text Generation with Hugging Face Transformers")
 
-# Streamlit app
-st.title("Text Generation with Transformers")
+# Input prompt from user
+prompt = st.text_area("Enter a prompt:", "this news is real pyresearch given right computer vision videos?")
 
-# Input prompt
-prompt = st.text_input("Enter your prompt:")
+# Load model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True).to(device)
 
-# Generate button
+# Generate text on button click
 if st.button("Generate"):
     with torch.no_grad():
-        # Tokenize and generate output
-        token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
+        token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
         output_ids = model.generate(
-            token_ids.to(model.device),
+            token_ids,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.1
        )
 
-        # Decode and display the generated text
+
         generated_text = tokenizer.decode(output_ids[0][token_ids.size(1):])
         st.text("Generated Text:")
-        st.text(generated_text)
+        st.write(generated_text)
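
Note on the decoding step kept by this commit: model.generate returns the prompt ids followed by the newly sampled ids, which is why app.py slices output_ids[0] from token_ids.size(1) onward before decoding, so only the continuation is shown. A minimal sketch of that slicing with stand-in tensors (the id values below are illustrative, not real phi-2 tokens):

import torch

token_ids = torch.tensor([[101, 102, 103]])            # encoded prompt, shape (1, 3)
output_ids = torch.tensor([[101, 102, 103, 7, 8, 9]])  # generate() echoes the prompt first
new_ids = output_ids[0][token_ids.size(1):]            # keep only the continuation
print(new_ids)                                         # tensor([7, 8, 9])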