sabahat-shakeel committed · verified
Commit 9d3a107 · 1 Parent(s): da5ffae

Update app.py

Files changed (1):
  1. app.py +7 -2
app.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
 
 # Load model and tokenizer
 @st.cache_resource
@@ -34,9 +35,13 @@ def generate_prompt(comment):
 def get_response(comment):
     prompt = generate_prompt(comment)
     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+
+    # Check if CUDA is available, otherwise use CPU
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
     outputs = model.generate(
-        input_ids=inputs["input_ids"].to("cuda"),
-        attention_mask=inputs["attention_mask"].to("cuda"),
+        input_ids=inputs["input_ids"].to(device),
+        attention_mask=inputs["attention_mask"].to(device),
         max_new_tokens=140,
         pad_token_id=tokenizer.pad_token_id  # Ensure padding is handled properly
     )
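
One caveat the diff leaves implicit: generate requires the model's weights to be on the same device as the input tensors, so the cached loader should also move the model with .to(device), otherwise a GPU run would still raise a device-mismatch error. Below is a minimal, self-contained sketch of the full pattern; the model name "gpt2" and the prompt are placeholders for illustration, not taken from this Space's app.py.

# Sketch of device-aware generation, assuming a hypothetical "gpt2" checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "gpt2"  # placeholder; substitute the model app.py actually loads

# Pick the device once, then keep model and inputs on it consistently.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # GPT-style tokenizers ship without a pad token; reuse EOS so padding works
    tokenizer.pad_token = tokenizer.eos_token

# Move the model to the same device the inputs will be sent to
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

inputs = tokenizer("Hello", return_tensors="pt", padding=True, truncation=True)
outputs = model.generate(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    max_new_tokens=140,
    pad_token_id=tokenizer.pad_token_id,  # Ensure padding is handled properly
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Computing the device at call time, as the commit does, also keeps the app portable: the same code runs on a CPU-only Space and on GPU hardware without edits.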