usmanyousaf's picture
Create app.py
917a8c3 verified
raw
history blame
4.32 kB
import io
import os
import re

from fastapi import FastAPI, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from groq import Groq
from pydantic import BaseModel
# Set up the Groq client.
# Fail fast with a readable message when the key is absent: writing the
# result of os.getenv() back into os.environ would raise an opaque
# "TypeError: str expected, not NoneType" for a missing variable.
_groq_api_key = os.getenv("GROQ_API_KEY")
if _groq_api_key is None:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")
client = Groq(api_key=_groq_api_key)

app = FastAPI()
# Pydantic model for the transcription result
class TranscriptionResponse(BaseModel):
    """Response body produced by the ``/transcribe`` endpoint."""

    # Text returned by the transcription request.
    transcription: str
class FeedbackResponse(BaseModel):
    """Response body produced by the ``/check_grammar`` endpoint."""

    # Reply text from the grammar-check prompt.
    grammar_feedback: str
    # Reply text from the vocabulary-check prompt.
    vocabulary_feedback: str
    # Score parsed out of the grammar reply by calculate_score
    # (0 when no score marker is found).
    grammar_score: int
    # Score parsed out of the vocabulary reply by calculate_score
    # (0 when no score marker is found).
    vocabulary_score: int
@app.get("/")
async def index():
    """Root endpoint: return a static welcome payload."""
    greeting = "Welcome to the Audio Transcription API!"
    return dict(message=greeting)
@app.post("/transcribe")
async def transcribe_audio(audio_data: bytes = File(...), language: str = Form(...)):
    """Transcribe an uploaded audio clip through the Groq transcription API.

    Args:
        audio_data: Raw bytes of the uploaded audio file.
        language: Language value forwarded unchanged to the transcription
            request.

    Returns:
        A JSON response shaped like ``{"transcription": <text>}``.

    Raises:
        HTTPException: 400 when the audio or the language is empty; 500 when
            the transcription request fails for any reason.
    """
    # Reject empty input up front (same policy as /check_grammar) instead of
    # spending an upstream request that can only fail.
    if not audio_data or not language:
        raise HTTPException(status_code=400, detail="Missing audio data or language selection")
    try:
        # The Groq client is synchronous; run it in a worker thread so this
        # coroutine does not block the event loop while waiting on the network.
        transcription = await run_in_threadpool(
            client.audio.transcriptions.create,
            file=("audio.wav", audio_data),
            model="whisper-large-v3",
            prompt="Transcribe the audio accurately based on the selected language.",
            response_format="text",
            language=language,
        )
        return JSONResponse(content=TranscriptionResponse(transcription=transcription).dict())
    except Exception as e:
        # Boundary handler: surface any upstream failure as a 500 while keeping
        # the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/check_grammar")
async def check_grammar(transcription: str = Form(...), language: str = Form(...)):
    """Collect grammar and vocabulary feedback (with scores) for a transcription.

    Two chat-completion requests are issued, one per aspect, and a numeric
    score is parsed out of each reply with ``calculate_score``.

    Args:
        transcription: Text to be reviewed.
        language: Language the text is expected to be written in.

    Returns:
        A JSON response with ``grammar_feedback``, ``vocabulary_feedback``,
        ``grammar_score`` and ``vocabulary_score``.

    Raises:
        HTTPException: 400 when either form field is empty; 500 when a model
            request or the response handling fails.
    """
    if not transcription or not language:
        raise HTTPException(status_code=400, detail="Missing transcription or language selection")

    # NOTE(review): `transcription` is caller-supplied and is interpolated
    # straight into the prompts, so it can steer the model (prompt injection).
    grammar_prompt = (
        f"Briefly check the grammar of the following text in {language}: {transcription}. "
        "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
        "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the grammar accuracy, reducing points for incorrect words, "
        "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
    )
    # NOTE(review): this prompt also says "check the grammar deeply", which
    # looks copied from the grammar prompt; confirm whether "vocabulary" was meant.
    vocabulary_prompt = (
        f"Check the vocabulary accuracy of the following text in {language}: {transcription}. "
        "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
        "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the vocabulary accuracy, reducing points for incorrect words, "
        "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
    )

    try:
        # The Groq client is synchronous; each request runs in a worker thread
        # so the event loop stays free while waiting on the network.
        # NOTE(review): confirm both model ids are still served by Groq.
        grammar_feedback = await run_in_threadpool(
            _request_feedback, "llama3-groq-70b-8192-tool-use-preview", grammar_prompt
        )
        vocabulary_feedback = await run_in_threadpool(
            _request_feedback, "llama-3.1-70b-versatile", vocabulary_prompt
        )

        # Calculate scores from feedback
        grammar_score = calculate_score(grammar_feedback)
        vocabulary_score = calculate_score(vocabulary_feedback)

        return JSONResponse(content=FeedbackResponse(
            grammar_feedback=grammar_feedback,
            vocabulary_feedback=vocabulary_feedback,
            grammar_score=grammar_score,
            vocabulary_score=vocabulary_score,
        ).dict())
    except Exception as e:
        # Boundary handler: report any failure as a 500 while keeping the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e


def _request_feedback(model: str, prompt: str) -> str:
    """Send *prompt* as a single user message to *model* and return the reply text."""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    # Presumably `content` may be None on some replies (TODO confirm against
    # the SDK types); fall back to an empty string so `.strip()` cannot fail.
    return (response.choices[0].message.content or "").strip()
# Matches "SCORE=8", "score: 7", "Score = 9", "**Score:** 6" and similar:
# the word "score" in any letter case, optional spaces / markdown asterisks,
# an "=" or ":" separator, then the digits that are captured.
_SCORE_PATTERN = re.compile(r"score\s*\**\s*[=:]\s*\**\s*(\d+)", re.IGNORECASE)


def calculate_score(feedback):
    """Extract the numeric score from a block of model feedback.

    The first occurrence of a ``score`` marker followed by ``=`` or ``:`` and
    an integer is used. Matching ignores letter case and tolerates whitespace
    or markdown asterisks around the separator.

    Args:
        feedback: Feedback text to scan; ``None`` or an empty string is
            treated as containing no score.

    Returns:
        The score as an ``int``, or ``0`` when no score marker is found.
    """
    if not feedback:
        return 0
    match = _SCORE_PATTERN.search(feedback)
    # Default to 0 so callers always receive an int.
    return int(match.group(1)) if match else 0