"""Gradio demo: find sentences in an uploaded text file that resemble a query.

The uploaded file is split into sentences on '.', every sentence and the
query are normalised (lower-cased, lemmatised, stop words removed), and
TF-IDF cosine similarity is used to select the sentences whose score reaches
``SIMILARITY_THRESHOLD``. The matches are returned as a JSON string.
"""

import functools
import json

import gradio as gr
import nltk
from nltk import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Download NLTK resources
nltk.download('punkt')
# NOTE(review): recent NLTK releases load the word tokenizer data under the
# name 'punkt_tab'; on releases that do not know this package the call is
# expected to only report a failure, not raise -- confirm for pinned versions.
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('stopwords')

# Minimum cosine similarity for a sentence to be reported as a match.
SIMILARITY_THRESHOLD = 0.15

# Shared by every preprocess() call instead of being rebuilt per sentence.
_LEMMATIZER = WordNetLemmatizer()


@functools.lru_cache(maxsize=None)
def _stop_words():
    """Return the English stop-word set, loaded from the corpus only once."""
    return frozenset(stopwords.words('english'))


def preprocess(sentence):
    """Normalise *sentence* for TF-IDF comparison.

    The text is lower-cased and tokenised; only alphanumeric tokens are kept,
    each is lemmatised, and lemmas that are English stop words are dropped.

    Args:
        sentence: Raw text to normalise.

    Returns:
        The surviving lemmas joined by single spaces (possibly an empty
        string).
    """
    stop_words = _stop_words()
    lemmas = (
        _LEMMATIZER.lemmatize(word)
        for word in word_tokenize(sentence.lower())
        if word.isalnum()
    )
    # Stop words are filtered after lemmatisation, so the check is made on
    # the lemma rather than on the surface form.
    return ' '.join(lemma for lemma in lemmas if lemma not in stop_words)


def find_most_similar(sentence, candidates, threshold=SIMILARITY_THRESHOLD):
    """Return the candidates whose TF-IDF cosine similarity reaches *threshold*.

    Args:
        sentence: The query text.
        candidates: Iterable of raw candidate sentences.
        threshold: Minimum similarity score (inclusive) for a match.

    Returns:
        A list of ``{"sentence": ..., "similarity_score": ...}`` dicts in the
        order of *candidates*, with scores rounded to four decimals. The list
        is empty when there are no candidates or when no text contains a
        usable term.
    """
    candidates = list(candidates)
    if not candidates:
        # cosine_similarity cannot be computed against zero candidate rows.
        return []

    query = preprocess(sentence)
    documents = [preprocess(candidate) for candidate in candidates]

    try:
        vectors = TfidfVectorizer().fit_transform([query] + documents)
    except ValueError:
        # Raised by the vectorizer when no document yields a single term
        # (e.g. everything was stop words): nothing can be similar.
        return []

    scores = cosine_similarity(vectors[0:1], vectors[1:]).flatten()
    return [
        # float() keeps the payload a plain Python number for json.dumps.
        {"sentence": candidate, "similarity_score": round(float(score), 4)}
        for candidate, score in zip(candidates, scores)
        if score >= threshold
    ]


def read_sentences_from_file(file_location):
    """Read a text file and split it into sentences on '.'.

    Newlines are replaced by spaces before splitting, so a sentence may span
    several lines of the file.

    Args:
        file_location: Path of the text file to read.

    Returns:
        The non-empty, whitespace-stripped sentences in file order.
    """
    # An explicit encoding makes the result independent of the platform
    # locale; 'utf-8-sig' also drops a leading BOM, and undecodable bytes are
    # replaced rather than aborting the whole request.
    with open(file_location, 'r', encoding='utf-8-sig', errors='replace') as file:
        text = file.read().replace('\n', ' ')
    pieces = (piece.strip() for piece in text.split('.'))
    return [piece for piece in pieces if piece]


def fetch_vectors(file, sentence):
    """Gradio callback: return the sentences of *file* similar to *sentence*.

    Args:
        file: The value delivered by the ``gr.File`` input; either an object
            exposing the path as ``.name`` or the path itself. ``None`` when
            nothing was uploaded.
        sentence: The question typed by the user.

    Returns:
        A JSON string (indent 4) with the list of matches; ``"[]"`` when no
        file was uploaded, the question is blank, or nothing matches.
    """
    if file is None or not sentence or not sentence.strip():
        return json.dumps([], indent=4)

    # Fall back to the value itself when it has no .name (plain path).
    file_location = getattr(file, 'name', file)
    chunks = read_sentences_from_file(file_location)
    similar_sentences = find_most_similar(
        sentence, chunks, threshold=SIMILARITY_THRESHOLD
    )
    return json.dumps(similar_sentences, indent=4)


# Interface
file_uploader = gr.File(label="Upload a .txt file")
text_input = gr.Textbox(label="Enter a sentence")
output_text = gr.Textbox(label="Similar Sentences JSON")

iface = gr.Interface(
    fn=fetch_vectors,
    inputs=[file_uploader, text_input],
    outputs=output_text,
    title="Simple RAG - For QA",
    description=(
        "Upload a text file and enter the question. "
        f"The threshold is set to {SIMILARITY_THRESHOLD}."
    ),
)

if __name__ == "__main__":
    # Only start the server when run as a script, so importing this module
    # (e.g. to reuse the helpers) does not launch the web app.
    iface.launch(debug=True)