Shushmita commited on
Commit
8a098bb
·
verified ·
1 Parent(s): e1f5b53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
 
4
- #to load the model pipeline
5
  @st.cache_resource()
6
  def load_model():
7
- return pipeline("text2text-generation", model="Salesforce/codet5p-220m")
 
 
 
8
 
9
- model = load_model()
10
 
11
  # Streamlit UI
12
  st.title("CodeCorrect AI")
@@ -16,11 +19,13 @@ code_input = st.text_area("Enter your code here:", height=200)
16
 
17
  if st.button("Correct Code"):
18
  if code_input.strip():
19
- response = model(code_input, max_length=512)
20
- corrected_code = response[0]['generated_text']
 
 
 
21
  st.text_area("Corrected Code:", corrected_code, height=200)
22
  else:
23
  st.warning("Please enter some code.")
24
 
25
-
26
-
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
+ # Load model and tokenizer
5
  @st.cache_resource()
6
  def load_model():
7
+ model_name = "Salesforce/codet5-small" # A better model for code correction
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
+ return model, tokenizer
11
 
12
+ model, tokenizer = load_model()
13
 
14
  # Streamlit UI
15
  st.title("CodeCorrect AI")
 
19
 
20
  if st.button("Correct Code"):
21
  if code_input.strip():
22
+ # Tokenize and generate corrected code
23
+ inputs = tokenizer(code_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
24
+ outputs = model.generate(**inputs, max_length=512)
25
+ corrected_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
26
+
27
  st.text_area("Corrected Code:", corrected_code, height=200)
28
  else:
29
  st.warning("Please enter some code.")
30
 
31
+ st.markdown("Powered by Hugging Face 🤗")