Shushmita committed on
Commit 4a69777 · verified · 1 Parent(s): adeb12d

Update app.py

Files changed (1):
  1. app.py +22 -22
app.py CHANGED
@@ -1,28 +1,28 @@
- import streamlit as st
- from transformers import AutoModelForCausalLM, AutoTokenizer
-
- @st.cache_resource()
- def load_model():
-     model_name = "bigcode/starcoder"
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name)
-     return model, tokenizer
-
- model, tokenizer = load_model()
-
- st.title("CodeCorrect AI")
- st.subheader("AI-powered Code Autocorrect Tool")
-
- code_input = st.text_area("Enter your code here:", height=200)
-
- if st.button("Correct Code"):
-     if code_input.strip():
-         prompt = f"### Fix the following code:\n{code_input}\n### Corrected version:\n"
-         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
-         outputs = model.generate(**inputs, max_length=512)
-         corrected_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         st.text_area("Corrected Code:", corrected_code, height=200)
-     else:
-         st.warning("Please enter some code.")
-
- st.markdown("Powered by Hugging Face 🤗")
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ @st.cache_resource()
+ def load_model():
+     model_name = "bigcode/starcoder"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     return model, tokenizer
+
+ model, tokenizer = load_model()
+
+ st.title("CodeCorrect AI")
+ st.subheader("AI-powered Code Autocorrect Tool")
+
+ code_input = st.text_area("Enter your code here:", height=200)
+
+ if st.button("Correct Code"):
+     if code_input.strip():
+         prompt = f"### Fix the following code:\n{code_input}\n### Corrected version:\n"
+         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
+         outputs = model.generate(**inputs, max_length=512)
+         corrected_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         st.text_area("Corrected Code:", corrected_code, height=200)
+     else:
+         st.warning("Please enter some code.")
+
+ st.markdown("Powered by Hugging Face 🤗")
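
A minimal sketch of the generation path used in app.py, runnable outside Streamlit as a quick smoke test. The model name is swapped for the small gpt2 checkpoint purely as a lightweight stand-in (bigcode/starcoder is a multi-billion-parameter model), the broken snippet is a made-up example, and padding=True is dropped since small checkpoints like GPT-2 ship without a pad token; the prompt format mirrors the one in app.py.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Stand-in model for a quick local test; app.py itself loads bigcode/starcoder.
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Deliberately broken snippet, same prompt template as app.py.
    code_input = "print('hello world'"
    prompt = f"### Fix the following code:\n{code_input}\n### Corrected version:\n"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))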